1 #pragma once 2 #include <ATen/core/TensorBase.h> 3 #include <optional> 4 5 6 namespace at::cuda::detail { 7 TORCH_API void f8f8bf16_rowwise( 8 at::Tensor XQ, // FP8 9 at::Tensor WQ, // FP8 10 at::Tensor x_scale, // FP32 11 at::Tensor w_scale, // FP32 12 std::optional<at::Tensor> bias, // BF16 13 bool use_fast_accum, 14 at::Tensor& out); 15 } // at::cuda::detail 16