#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>

// Determine if the architecture supports rowwise scaled mm
// Currently failing on Windows with: https://github.com/NVIDIA/cutlass/issues/1571
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000

#define BUILD_ROWWISE_FP8_KERNEL
#endif

#if defined(BUILD_ROWWISE_FP8_KERNEL)

// We are going to override the cuTensorMapEncodeTiled driver API with our lazy loader
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
    CUtensorMap* tensorMap,
    CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank,
    void* globalAddress,
    const cuuint64_t* globalDim,
    const cuuint64_t* globalStrides,
    const cuuint32_t* boxDim,
    const cuuint32_t* elementStrides,
    CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle,
    CUtensorMapL2promotion l2Promotion,
    CUtensorMapFloatOOBfill oobFill) {
  return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
      tensorMap,
      tensorDataType,
      tensorRank,
      globalAddress,
      globalDim,
      globalStrides,
      boxDim,
      elementStrides,
      interleave,
      swizzle,
      l2Promotion,
      oobFill);
}

#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/half.h>
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include <cutlass/util/host_tensor.h>

// Rename the global function symbol
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
#include <cute/tensor.hpp>
#undef cuTensorMapEncodeTiled
// Set everything back to normal
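// Note: cute's TMA utilities call cuTensorMapEncodeTiled by name, so the
// temporary #define above redirects those calls to nvrtc_cuTensorMapEncodeTiled,
// letting the driver API be resolved through ATen's lazily loaded NVRTC/driver
// bindings rather than through a direct link-time dependency on the CUDA driver.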

#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/epilogue/collective/collective_builder.hpp>

#include <cute/atom/mma_atom.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>

namespace {

constexpr int kNumSMsForH100 = 132;

using DtypeScale = float;
using DtypeAccum = float;
using DtypeEpilogue = float;
using DtypeOutput = cutlass::bfloat16_t;

using Multiply = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::multiplies,
    DtypeEpilogue,
    DtypeEpilogue,
    cutlass::FloatRoundStyle::round_to_nearest>;

using Add = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus,
    DtypeEpilogue,
    DtypeEpilogue,
    cutlass::FloatRoundStyle::round_to_nearest>;

using Cast = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::epilogue::thread::Identity,
    DtypeOutput,
    DtypeEpilogue,
    cutlass::FloatRoundStyle::round_to_nearest>;

template <bool PingPong, bool FastAccum>
struct Schedule;

template <>
struct Schedule</*PingPong=*/false, /*FastAccum=*/false> {
  using type = cutlass::gemm::KernelTmaWarpSpecialized;
};

template <>
struct Schedule</*PingPong=*/true, /*FastAccum=*/false> {
  using type = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
};

template <>
struct Schedule</*PingPong=*/false, /*FastAccum=*/true> {
  using type = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
};

template <>
struct Schedule</*PingPong=*/true, /*FastAccum=*/true> {
  using type = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
};
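
// Schedule maps the (PingPong, FastAccum) pair onto the corresponding CUTLASS
// SM90 TMA warp-specialized kernel schedule; the FP8FastAccum variants select
// the faster (but less numerically strict) FP8 accumulation path.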

int ceildiv(int a, int b) {
  return (a + b - 1) / b;
}

int round_up_to_nearest_multiple(int a, int b) {
  return ceildiv(a, b) * b;
}
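
// For example, ceildiv(100, 64) == 2 and round_up_to_nearest_multiple(100, 64) == 128.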

// Cutlass rowwise kernel
template <
    typename TileShape,
    typename ClusterShape,
    typename PingPong,
    typename Transposed,
    typename FastAccum,
    typename DtypeA,
    typename DtypeB,
    typename DtypeBias>
void f8f8bf16_rowwise_impl(
    at::Tensor XQ, // FP8
    at::Tensor WQ, // FP8
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out) {
  int M = XQ.size(0);
  int N = WQ.size(1);
  int K = XQ.size(1);

  // Workaround for https://github.com/pytorch/pytorch/issues/133334.
  if (M % 256 > 0) {
    int padded_M = ((M - 1) / 256 + 1) * 256;
    at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
    padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
        .copy_(std::move(x_scale));
    x_scale = std::move(padded_x_scale);
  }
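  // (Assumed rationale for the workaround above: padding x_scale to a multiple
  // of 256 rows keeps the epilogue's column-broadcast of x_scale in bounds for
  // the final partial tile; the padded entries never influence the stored
  // output, which only covers the original M rows.)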

  using LayoutInputA = cutlass::layout::RowMajor;
  constexpr int AlignmentInputA = 16 / sizeof(DtypeA);

  using LayoutInputB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentInputB = 16 / sizeof(DtypeB);

  using LayoutOutput = std::conditional_t<
      Transposed::value,
      cutlass::layout::ColumnMajor,
      cutlass::layout::RowMajor>;
  constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);

  // Tag indicating the minimum SM that supports the intended feature
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  // Implement rowwise scaling epilogue.
  constexpr int ColBroadcastStages = 0;
  constexpr int RowBroadcastStages = PingPong::value ? 2 : 1;

  using XScale = cutlass::epilogue::fusion::
      Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;

  using WScale = cutlass::epilogue::fusion::
      Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;

  using Bias = std::conditional_t<
      Transposed::value,
      cutlass::epilogue::fusion::
          Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
      cutlass::epilogue::fusion::
          Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
      Cast,
      cutlass::epilogue::fusion::Sm90EVT<
          Add,
          Bias,
          cutlass::epilogue::fusion::Sm90EVT<
              Multiply,
              XScale,
              cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>>>;
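  // Evaluated innermost-first, this visitor tree computes, per output element:
  //   out = bf16(bias + x_scale * (w_scale * acc))
  // where x_scale provides one value per output row (column-broadcast) and
  // w_scale one value per output column (row-broadcast).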

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          TileShape,
          ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto,
          DtypeAccum,
          DtypeEpilogue,
          DtypeOutput,
          LayoutOutput,
          AlignmentOutput,
          DtypeOutput,
          LayoutOutput,
          AlignmentOutput,
          cutlass::epilogue::TmaWarpSpecialized,
          EpilogueEVT>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          DtypeA,
          LayoutInputA,
          AlignmentInputA,
          DtypeB,
          LayoutInputB,
          AlignmentInputB,
          DtypeAccum,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          typename Schedule<PingPong::value, FastAccum::value>::type>::
          CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideInputA = typename Gemm::GemmKernel::StrideA;
  using StrideInputB = typename Gemm::GemmKernel::StrideB;
  using StrideOutput = typename Gemm::GemmKernel::StrideC;

  StrideInputA stride_a = cutlass::make_cute_packed_stride(
      StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
  StrideInputB stride_b = cutlass::make_cute_packed_stride(
      StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
  StrideOutput stride_output = cutlass::make_cute_packed_stride(
      StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
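  // Note: the second shape element passed to make_cute_packed_stride is each
  // tensor's actual leading stride rather than K/N, so the resulting CUTLASS
  // strides also describe row/column-padded inputs (check_inputs only requires
  // a.stride(0) >= K and b.stride(1) >= K).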

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K},
      {reinterpret_cast<DtypeA*>(XQ.data_ptr()),
       stride_a,
       reinterpret_cast<DtypeB*>(WQ.data_ptr()),
       stride_b},
      {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
                           : nullptr},
         {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
          {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}}},
       reinterpret_cast<DtypeOutput*>(out.data_ptr()),
       stride_output,
       reinterpret_cast<DtypeOutput*>(out.data_ptr()),
       stride_output}};

  Gemm gemm;

  // Using the arguments, query for the extra workspace required for the matrix
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  auto workspace = XQ.new_empty(
      {static_cast<int64_t>(workspace_size)},
      at::TensorOptions().dtype(at::kByte));

  // Check whether the problem size is supported
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm.initialize(arguments, workspace.data_ptr());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm(at::cuda::getCurrentCUDAStream());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename ClusterShape, typename... Types>
void dispatch_fp8_rowwise_kernel_on_tile_size(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out) {
  int M = XQ.size(0);
  int N = WQ.size(1);

  // We prefer to use smaller tiles (less wasted compute in case of padding),
  // but if this causes us to have more CUDA blocks than there are SMs on the
  // GPU then we'll hit wave quantization, hence we'll switch to larger tiles.
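  // For example, with a 2x1x1 cluster (132 SMs / 2 = 66 clusters), M = N = 256
  // gives ceildiv(256, 128) * ceildiv(256, 128) = 4 blocks, so the 64x128x128
  // tile is used, while M = N = 4096 gives 32 * 32 = 1024 blocks and falls
  // through to the 128x128x128 tile below.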
  if (ceildiv(M, 64 * cute::get<0>(ClusterShape{})) *
          ceildiv(N, 128 * cute::get<1>(ClusterShape{})) <=
      kNumSMsForH100 / cute::size(ClusterShape{})) {
    return f8f8bf16_rowwise_impl<
        /*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
        ClusterShape,
        /*PingPong=*/std::false_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  } else {
    return f8f8bf16_rowwise_impl<
        /*TileShape=*/cute::Shape<cute::_128, cute::_128, cute::_128>,
        ClusterShape,
        /*PingPong=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
}

template <
    typename ClusterShape,
    typename Transposed,
    typename FastAccum,
    typename DtypeA,
    typename DtypeB,
    typename DtypeBias>
void handle_transposition(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out) {
  if constexpr (!Transposed::value) {
    dispatch_fp8_rowwise_kernel_on_tile_size<
        ClusterShape,
        Transposed,
        FastAccum,
        DtypeA,
        DtypeB,
        DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out);
  } else {
    dispatch_fp8_rowwise_kernel_on_tile_size<
        ClusterShape,
        Transposed,
        FastAccum,
        DtypeB,
        DtypeA,
        DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t());
  }
}

template <typename... Types>
void dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out) {
  int M = XQ.size(0);
  int N = WQ.size(1);

  // All the tiles we use have sizes which are multiples of 64, hence any
  // non-multiple of 64 will get padded anyway. Let's round up to simplify.
  M = round_up_to_nearest_multiple(M, 64);
  N = round_up_to_nearest_multiple(N, 64);

  // Small/skinny shapes with odd multiples of 64.
  if (M == 64 && N >= 3072) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::false_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
  if (N == 64 && M >= 3072) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
  if (M == 192 && N >= 4096) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
  if (N == 192 && M >= 4096) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::false_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }

  // Now to odd multiples of 128 (but only if not too large).
  if (M * N <= 4096 * 4096) {
    if (M % 256 > 0 && N % 256 == 0) {
      return handle_transposition<
          /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
          /*Transposed=*/std::true_type,
          Types...>(XQ, WQ, x_scale, w_scale, bias, out);
    }
    if (N % 256 > 0 && M % 256 == 0) {
      return handle_transposition<
          /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
          /*Transposed=*/std::false_type,
          Types...>(XQ, WQ, x_scale, w_scale, bias, out);
    }
  }
  if (M % 256 > 0 && N % 256 > 0) {
    if ((M <= N) ^ (M * N <= 1024 * 1024)) {
      return handle_transposition<
          /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
          /*Transposed=*/std::true_type,
          Types...>(XQ, WQ, x_scale, w_scale, bias, out);
    } else {
      return handle_transposition<
          /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
          /*Transposed=*/std::false_type,
          Types...>(XQ, WQ, x_scale, w_scale, bias, out);
    }
  }

  // General case for large tensors.
  if ((M <= N) ^ (M >= 2048 && N >= 2048)) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  } else {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
        /*Transposed=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
}

template <typename... Types>
void dispatch_fp8_rowwise_kernel_on_fast_accum(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum,
    at::Tensor out) {
  if (use_fast_accum) {
    dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
        std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  } else {
    dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
        std::false_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
}

template <typename... Types>
void dispatch_fp8_rowwise_kernel_on_input_dtypes(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum,
    at::Tensor out) {
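  // Note: only the (e5m2, e4m3) and (e4m3, e4m3) operand combinations are
  // instantiated here; check_inputs requires WQ to always be e4m3fn.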
  if (XQ.dtype() == at::kFloat8_e5m2) {
    dispatch_fp8_rowwise_kernel_on_fast_accum<
        cutlass::float_e5m2_t,
        cutlass::float_e4m3_t,
        Types...>(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
  } else {
    dispatch_fp8_rowwise_kernel_on_fast_accum<
        cutlass::float_e4m3_t,
        cutlass::float_e4m3_t,
        Types...>(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
  }
}

void dispatch_fp8_rowwise_kernel_on_bias_dtype(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum,
    at::Tensor out) {
  if (bias.has_value() && bias->dtype() == at::kBFloat16) {
    dispatch_fp8_rowwise_kernel_on_input_dtypes<cutlass::bfloat16_t>(
        XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
  } else {
    dispatch_fp8_rowwise_kernel_on_input_dtypes<float>(
        XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
  }
}
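
// The stride checks below require unit stride along each tensor's contiguous
// dimension but allow `a`, `b`, and `out` to be padded along the other
// dimension (a.stride(0) >= K, b.stride(1) >= K, out.stride(0) >= N).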
void check_inputs(
    const at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& scale_a,
    const at::Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const at::Tensor& out) {
  TORCH_CHECK(a.is_cuda());
  TORCH_CHECK(a.device() == b.device());
  TORCH_CHECK(scale_a.device() == a.device());
  TORCH_CHECK(scale_b.device() == b.device());

  TORCH_CHECK(a.dtype() == at::kFloat8_e4m3fn || a.dtype() == at::kFloat8_e5m2);
  TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
  TORCH_CHECK(scale_a.dtype() == at::kFloat);
  TORCH_CHECK(scale_b.dtype() == at::kFloat);

  TORCH_CHECK(a.dim() == 2);
  TORCH_CHECK(b.dim() == 2);
  TORCH_CHECK(a.size(1) == b.size(0));
  TORCH_CHECK(scale_a.dim() == 2);
  TORCH_CHECK(scale_b.dim() == 2);
  TORCH_CHECK(scale_a.size(0) == a.size(0));
  TORCH_CHECK(scale_a.size(1) == 1);
  TORCH_CHECK(scale_b.size(0) == 1);
  TORCH_CHECK(scale_b.size(1) == b.size(1));

  TORCH_CHECK(a.stride(1) == 1);
  TORCH_CHECK(a.stride(0) >= a.size(1));
  TORCH_CHECK(b.stride(0) == 1);
  TORCH_CHECK(b.stride(1) >= b.size(0));
  TORCH_CHECK(scale_a.stride(0) == 1);
  TORCH_CHECK(scale_b.stride(1) == 1);

  if (bias.has_value()) {
    TORCH_CHECK(bias->device() == b.device());
    TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16);
    TORCH_CHECK(bias->dim() == 1);
    TORCH_CHECK(bias->size(0) == b.size(1));
    TORCH_CHECK(bias->stride(0) == 1);
  }

  TORCH_CHECK(out.device() == a.device());
  TORCH_CHECK(out.dtype() == at::kBFloat16);
  TORCH_CHECK(out.dim() == 2);
  TORCH_CHECK(out.size(0) == a.size(0));
  TORCH_CHECK(out.size(1) == b.size(1));
  TORCH_CHECK(out.stride(1) == 1);
  TORCH_CHECK(out.stride(0) >= out.size(1));
}

} // namespace

#endif // defined(BUILD_ROWWISE_FP8_KERNEL)

namespace at::cuda::detail {
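// Expected calling convention (enforced by check_inputs): XQ is [M, K] FP8
// row-major, WQ is [K, N] FP8 column-major, x_scale is [M, 1] FP32, w_scale is
// [1, N] FP32, bias (if given) is a length-N FP32 or BF16 vector, and out is
// [M, N] BF16 row-major.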
void f8f8bf16_rowwise(
    at::Tensor XQ, // FP8
    at::Tensor WQ, // FP8
    at::Tensor x_scale, // FP32
    at::Tensor w_scale, // FP32
    std::optional<at::Tensor> bias, // BF16 or FP32
    bool use_fast_accum,
    at::Tensor& out) {
#if defined(BUILD_ROWWISE_FP8_KERNEL)
  check_inputs(XQ, WQ, x_scale, w_scale, bias, out);

  dispatch_fp8_rowwise_kernel_on_bias_dtype(
      XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
#else // BUILD_ROWWISE_FP8_KERNEL
  TORCH_CHECK(
      false, "Rowwise scaling is not currently supported on your device");
#endif
}

} // namespace at::cuda::detail