// aten/src/ATen/native/cuda/RowwiseScaledMM.cu
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>

// Determine if the architecture supports rowwise scaled mm
// Currently failing on Windows with: https://github.com/NVIDIA/cutlass/issues/1571
#if !defined(USE_ROCM) && !defined(_WIN32) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000

#define BUILD_ROWWISE_FP8_KERNEL
#endif

#if defined(BUILD_ROWWISE_FP8_KERNEL)

// We are going to override the cuTensorMapEncodeTiled driver api with our lazy loader
static CUresult CUDAAPI nvrtc_cuTensorMapEncodeTiled(
    CUtensorMap* tensorMap,
    CUtensorMapDataType tensorDataType,
    cuuint32_t tensorRank,
    void* globalAddress,
    const cuuint64_t* globalDim,
    const cuuint64_t* globalStrides,
    const cuuint32_t* boxDim,
    const cuuint32_t* elementStrides,
    CUtensorMapInterleave interleave,
    CUtensorMapSwizzle swizzle,
    CUtensorMapL2promotion l2Promotion,
    CUtensorMapFloatOOBfill oobFill) {
  return at::globalContext().getNVRTC().cuTensorMapEncodeTiled(
      tensorMap,
      tensorDataType,
      tensorRank,
      globalAddress,
      globalDim,
      globalStrides,
      boxDim,
      elementStrides,
      interleave,
      swizzle,
      l2Promotion,
      oobFill);
}


#include <cutlass/core_io.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/half.h>
#include <cutlass/numeric_types.h>
#include <cutlass/trace.h>
#include <cutlass/util/host_tensor.h>

// Rename the global function symbol
#define cuTensorMapEncodeTiled nvrtc_cuTensorMapEncodeTiled
#include <cute/tensor.hpp>
#undef cuTensorMapEncodeTiled
// Set everything back to normal
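// (While <cute/tensor.hpp> is being preprocessed, every reference it makes to
// cuTensorMapEncodeTiled is rewritten to the nvrtc_ wrapper above, so TMA
// descriptor creation goes through PyTorch's lazily loaded driver bindings
// instead of a direct link-time dependency on the driver symbol.)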

#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/epilogue/collective/collective_builder.hpp>

#include <cute/atom/mma_atom.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>


namespace {

constexpr int kNumSMsForH100 = 132;
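// 132 is the SM count of the (SXM) H100; the tile-size heuristic below uses it
// to estimate whether a grid still fits in a single wave.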

using DtypeScale = float;
using DtypeAccum = float;
using DtypeEpilogue = float;
using DtypeOutput = cutlass::bfloat16_t;

using Multiply = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::multiplies,
    DtypeEpilogue,
    DtypeEpilogue,
    cutlass::FloatRoundStyle::round_to_nearest>;

using Add = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::plus,
    DtypeEpilogue,
    DtypeEpilogue,
    cutlass::FloatRoundStyle::round_to_nearest>;

using Cast = cutlass::epilogue::fusion::Sm90Compute<
    cutlass::epilogue::thread::Identity,
    DtypeOutput,
    DtypeEpilogue,
    cutlass::FloatRoundStyle::round_to_nearest>;

template <bool PingPong, bool FastAccum>
struct Schedule;

template <>
struct Schedule</*PingPong=*/false, /*FastAccum=*/false> {
  using type = cutlass::gemm::KernelTmaWarpSpecialized;
};

template <>
struct Schedule</*PingPong=*/true, /*FastAccum=*/false> {
  using type = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
};

template <>
struct Schedule</*PingPong=*/false, /*FastAccum=*/true> {
  using type = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
};

template <>
struct Schedule</*PingPong=*/true, /*FastAccum=*/true> {
  using type = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
};

int ceildiv(int a, int b) {
  return (a + b - 1) / b;
}

int round_up_to_nearest_multiple(int a, int b) {
  return ceildiv(a, b) * b;
}
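// e.g. ceildiv(300, 256) == 2 and round_up_to_nearest_multiple(300, 256) == 512.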

// Cutlass rowwise kernel
template <
    typename TileShape,
    typename ClusterShape,
    typename PingPong,
    typename Transposed,
    typename FastAccum,
    typename DtypeA,
    typename DtypeB,
    typename DtypeBias>
void f8f8bf16_rowwise_impl(
    at::Tensor XQ, // FP8
    at::Tensor WQ, // FP8
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out) {
  int M = XQ.size(0);
  int N = WQ.size(1);
  int K = XQ.size(1);

  // Workaround for https://github.com/pytorch/pytorch/issues/133334.
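  // The row scales are copied into a buffer padded up to a multiple of 256
  // rows (e.g. M == 300 becomes padded_M == 512); the GEMM problem size stays
  // {M, N, K}, so only the first M entries are actually read.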
  if (M % 256 > 0) {
    int padded_M = ((M - 1) / 256 + 1) * 256;
    at::Tensor padded_x_scale = x_scale.new_empty({padded_M, 1});
    padded_x_scale.slice(/*dim=*/0, /*start=*/0, /*end=*/M)
        .copy_(std::move(x_scale));
    x_scale = std::move(padded_x_scale);
  }

  using LayoutInputA = cutlass::layout::RowMajor;
  constexpr int AlignmentInputA = 16 / sizeof(DtypeA);

  using LayoutInputB = cutlass::layout::ColumnMajor;
  constexpr int AlignmentInputB = 16 / sizeof(DtypeB);

  using LayoutOutput = std::conditional_t<
      Transposed::value,
      cutlass::layout::ColumnMajor,
      cutlass::layout::RowMajor>;
  constexpr int AlignmentOutput = 16 / sizeof(DtypeOutput);

  // Tag indicating the minimum SM that supports the intended feature
  using ArchTag = cutlass::arch::Sm90;
  using OperatorClass = cutlass::arch::OpClassTensorOp;

  // Implement rowwise scaling epilogue.
  constexpr int ColBroadcastStages = 0;
  constexpr int RowBroadcastStages = PingPong::value ? 2 : 1;

  using XScale = cutlass::epilogue::fusion::
      Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeScale>;

  using WScale = cutlass::epilogue::fusion::
      Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeScale>;
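  // XScale broadcasts one scale per output row (a length-M column vector,
  // matching x_scale's [M, 1] shape); WScale broadcasts one scale per output
  // column (a length-N row vector, matching w_scale's [1, N] shape).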

  using Bias = std::conditional_t<
      Transposed::value,
      cutlass::epilogue::fusion::
          Sm90ColBroadcast<ColBroadcastStages, TileShape, DtypeBias>,
      cutlass::epilogue::fusion::
          Sm90RowBroadcast<RowBroadcastStages, TileShape, DtypeBias>>;

  using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

  using EpilogueEVT = cutlass::epilogue::fusion::Sm90EVT<
      Cast,
      cutlass::epilogue::fusion::Sm90EVT<
          Add,
          Bias,
          cutlass::epilogue::fusion::Sm90EVT<
              Multiply,
              XScale,
              cutlass::epilogue::fusion::Sm90EVT<Multiply, WScale, Accum>>>>;
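  // Reading the tree inside-out, the fused epilogue computes
  //   out[m, n] = bf16(bias[n] + x_scale[m] * (w_scale[n] * acc[m, n]))
  // (with the bias broadcast along the other axis in the Transposed case).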

  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          TileShape,
          ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto,
          DtypeAccum,
          DtypeEpilogue,
          DtypeOutput,
          LayoutOutput,
          AlignmentOutput,
          DtypeOutput,
          LayoutOutput,
          AlignmentOutput,
          cutlass::epilogue::TmaWarpSpecialized,
          EpilogueEVT>::CollectiveOp;

  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag,
          OperatorClass,
          DtypeA,
          LayoutInputA,
          AlignmentInputA,
          DtypeB,
          LayoutInputB,
          AlignmentInputB,
          DtypeAccum,
          TileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          typename Schedule<PingPong::value, FastAccum::value>::type>::
          CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      cute::Shape<int, int, int>,
      CollectiveMainloop,
      CollectiveEpilogue>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  using StrideInputA = typename Gemm::GemmKernel::StrideA;
  using StrideInputB = typename Gemm::GemmKernel::StrideB;
  using StrideOutput = typename Gemm::GemmKernel::StrideC;

  StrideInputA stride_a = cutlass::make_cute_packed_stride(
      StrideInputA{}, cute::make_shape(M, static_cast<int>(XQ.stride(0)), 1));
  StrideInputB stride_b = cutlass::make_cute_packed_stride(
      StrideInputB{}, cute::make_shape(N, static_cast<int>(WQ.stride(1)), 1));
  StrideOutput stride_output = cutlass::make_cute_packed_stride(
      StrideOutput{}, cute::make_shape(M, static_cast<int>(out.stride(0)), 1));
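  // The nested braces in the epilogue arguments below mirror the EVT tree
  // defined above, i.e. Cast(Add(Bias, Mul(XScale, Mul(WScale, Accum)))):
  // first the bias/x_scale/w_scale pointers, then the C/D pointers and strides.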

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K},
      {reinterpret_cast<DtypeA*>(XQ.data_ptr()),
       stride_a,
       reinterpret_cast<DtypeB*>(WQ.data_ptr()),
       stride_b},
      {{{{bias.has_value() ? reinterpret_cast<DtypeBias*>(bias->data_ptr())
                           : nullptr},
         {{reinterpret_cast<DtypeScale*>(x_scale.data_ptr())},
          {{reinterpret_cast<DtypeScale*>(w_scale.data_ptr())}}}}},
       reinterpret_cast<DtypeOutput*>(out.data_ptr()),
       stride_output,
       reinterpret_cast<DtypeOutput*>(out.data_ptr()),
       stride_output}};

  Gemm gemm;

  // Using the arguments, query for extra workspace required for matrix
  // multiplication computation
  size_t workspace_size = Gemm::get_workspace_size(arguments);

  // Allocate workspace memory
  auto workspace = XQ.new_empty(
      {static_cast<int64_t>(workspace_size)},
      at::TensorOptions().dtype(at::kByte));

  // Check whether the problem size is supported or not
  cutlass::Status status = gemm.can_implement(arguments);
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot implement");
  }

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm.initialize(arguments, workspace.data_ptr());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error("cutlass cannot initialize");
  }

  status = gemm(at::cuda::getCurrentCUDAStream());
  if (status != cutlass::Status::kSuccess) {
    throw std::runtime_error(
        std::string("cutlass cannot run: ") +
        cutlass::cutlassGetStatusString(status));
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename ClusterShape, typename... Types>
void dispatch_fp8_rowwise_kernel_on_tile_size(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out) {
  int M = XQ.size(0);
  int N = WQ.size(1);

  // We prefer to use smaller tiles (less wasted compute in case of padding),
  // but if this causes us to have more CUDA blocks than there are SMs on the
  // GPU then we'll hit wave quantization, hence we'll switch to larger tiles.
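  // e.g. with ClusterShape = <2, 1, 1>: for M = N = 1024 the small-tile grid is
  // ceildiv(1024, 64 * 2) * ceildiv(1024, 128 * 1) = 8 * 8 = 64 <= 132 / 2 = 66
  // blocks, so 64x128 tiles are kept; for M = N = 2048 it would be 16 * 16 = 256
  // blocks, so we fall back to 128x128 pingpong tiles.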
  if (ceildiv(M, 64 * cute::get<0>(ClusterShape{})) *
          ceildiv(N, 128 * cute::get<1>(ClusterShape{})) <=
      kNumSMsForH100 / cute::size(ClusterShape{})) {
    return f8f8bf16_rowwise_impl<
        /*TileShape=*/cute::Shape<cute::_64, cute::_128, cute::_128>,
        ClusterShape,
        /*PingPong=*/std::false_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  } else {
    return f8f8bf16_rowwise_impl<
        /*TileShape=*/cute::Shape<cute::_128, cute::_128, cute::_128>,
        ClusterShape,
        /*PingPong=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
}

template <
    typename ClusterShape,
    typename Transposed,
    typename FastAccum,
    typename DtypeA,
    typename DtypeB,
    typename DtypeBias>
void handle_transposition(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out) {
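  // out = XQ @ WQ is equivalent to out.t() = WQ.t() @ XQ.t(), so the transposed
  // path simply swaps the operands (together with their scales and dtypes) and
  // writes into out.t(); the Transposed flag then makes the kernel treat the
  // output (and the bias broadcast) as column-major.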
  if constexpr (!Transposed::value) {
    dispatch_fp8_rowwise_kernel_on_tile_size<
        ClusterShape,
        Transposed,
        FastAccum,
        DtypeA,
        DtypeB,
        DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out);
  } else {
    dispatch_fp8_rowwise_kernel_on_tile_size<
        ClusterShape,
        Transposed,
        FastAccum,
        DtypeB,
        DtypeA,
        DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t());
  }
}

template <typename... Types>
void dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    at::Tensor out) {
  int M = XQ.size(0);
  int N = WQ.size(1);

  // All the tiles we use have sizes which are multiples of 64, hence any
  // non-multiple of 64 will get padded anyways. Let's round up to simplify.
  M = round_up_to_nearest_multiple(M, 64);
  N = round_up_to_nearest_multiple(N, 64);

  // Small/skinny shapes with odd multiples of 64.
  if (M == 64 && N >= 3072) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::false_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
  if (N == 64 && M >= 3072) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
  if (M == 192 && N >= 4096) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
  if (N == 192 && M >= 4096) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::false_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }

  // Now to odd multiples of 128 (but only if not too large).
  if (M * N <= 4096 * 4096) {
    if (M % 256 > 0 && N % 256 == 0) {
      return handle_transposition<
          /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
          /*Transposed=*/std::true_type,
          Types...>(XQ, WQ, x_scale, w_scale, bias, out);
    }
    if (N % 256 > 0 && M % 256 == 0) {
      return handle_transposition<
          /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
          /*Transposed=*/std::false_type,
          Types...>(XQ, WQ, x_scale, w_scale, bias, out);
    }
  }
  if (M % 256 > 0 && N % 256 > 0) {
    if ((M <= N) ^ (M * N <= 1024 * 1024)) {
      return handle_transposition<
          /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
          /*Transposed=*/std::true_type,
          Types...>(XQ, WQ, x_scale, w_scale, bias, out);
    } else {
      return handle_transposition<
          /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
          /*Transposed=*/std::false_type,
          Types...>(XQ, WQ, x_scale, w_scale, bias, out);
    }
  }

  // General case for large tensors.
  if ((M <= N) ^ (M >= 2048 && N >= 2048)) {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_1, cute::_2, cute::_1>,
        /*Transposed=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  } else {
    return handle_transposition<
        /*ClusterShape=*/cute::Shape<cute::_2, cute::_1, cute::_1>,
        /*Transposed=*/std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
}

template <typename... Types>
void dispatch_fp8_rowwise_kernel_on_fast_accum(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum,
    at::Tensor out) {
  if (use_fast_accum) {
    dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
        std::true_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  } else {
    dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose<
        std::false_type,
        Types...>(XQ, WQ, x_scale, w_scale, bias, out);
  }
}

template <typename... Types>
void dispatch_fp8_rowwise_kernel_on_input_dtypes(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum,
    at::Tensor out) {
  if (XQ.dtype() == at::kFloat8_e5m2) {
    dispatch_fp8_rowwise_kernel_on_fast_accum<
        cutlass::float_e5m2_t,
        cutlass::float_e4m3_t,
        Types...>(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
  } else {
    dispatch_fp8_rowwise_kernel_on_fast_accum<
        cutlass::float_e4m3_t,
        cutlass::float_e4m3_t,
        Types...>(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
  }
}

void dispatch_fp8_rowwise_kernel_on_bias_dtype(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum,
    at::Tensor out) {
  if (bias.has_value() && bias->dtype() == at::kBFloat16) {
    dispatch_fp8_rowwise_kernel_on_input_dtypes<cutlass::bfloat16_t>(
        XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
  } else {
    dispatch_fp8_rowwise_kernel_on_input_dtypes<float>(
        XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
  }
}

void check_inputs(
    const at::Tensor& a,
    const at::Tensor& b,
    const at::Tensor& scale_a,
    const at::Tensor& scale_b,
    const std::optional<at::Tensor>& bias,
    const at::Tensor& out) {
  TORCH_CHECK(a.is_cuda());
  TORCH_CHECK(a.device() == b.device());
  TORCH_CHECK(scale_a.device() == a.device());
  TORCH_CHECK(scale_b.device() == b.device());

  TORCH_CHECK(a.dtype() == at::kFloat8_e4m3fn || a.dtype() == at::kFloat8_e5m2);
  TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn);
  TORCH_CHECK(scale_a.dtype() == at::kFloat);
  TORCH_CHECK(scale_b.dtype() == at::kFloat);

  TORCH_CHECK(a.dim() == 2);
  TORCH_CHECK(b.dim() == 2);
  TORCH_CHECK(a.size(1) == b.size(0));
  TORCH_CHECK(scale_a.dim() == 2);
  TORCH_CHECK(scale_b.dim() == 2);
  TORCH_CHECK(scale_a.size(0) == a.size(0));
  TORCH_CHECK(scale_a.size(1) == 1);
  TORCH_CHECK(scale_b.size(0) == 1);
  TORCH_CHECK(scale_b.size(1) == b.size(1));

  TORCH_CHECK(a.stride(1) == 1);
  TORCH_CHECK(a.stride(0) >= a.size(1));
  TORCH_CHECK(b.stride(0) == 1);
  TORCH_CHECK(b.stride(1) >= b.size(0));
  TORCH_CHECK(scale_a.stride(0) == 1);
  TORCH_CHECK(scale_b.stride(1) == 1);

  if (bias.has_value()) {
    TORCH_CHECK(bias->device() == b.device());
    TORCH_CHECK(bias->dtype() == at::kFloat || bias->dtype() == at::kBFloat16);
    TORCH_CHECK(bias->dim() == 1);
    TORCH_CHECK(bias->size(0) == b.size(1));
    TORCH_CHECK(bias->stride(0) == 1);
  }

  TORCH_CHECK(out.device() == a.device());
  TORCH_CHECK(out.dtype() == at::kBFloat16);
  TORCH_CHECK(out.dim() == 2);
  TORCH_CHECK(out.size(0) == a.size(0));
  TORCH_CHECK(out.size(1) == b.size(1));
  TORCH_CHECK(out.stride(1) == 1);
  TORCH_CHECK(out.stride(0) >= out.size(1));
}

} // namespace

#endif // BUILD_ROWWISE_FP8_KERNEL

namespace at::cuda::detail {
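// Usage sketch (illustrative only; the tensor names and the plain dtype casts
// are assumptions for the example, real callers quantize with properly
// computed rowwise/columnwise scales). Assuming x is [M, K] and w is [N, K],
// both row-major CUDA tensors:
//
//   auto XQ = x.to(at::kFloat8_e4m3fn);         // [M, K] FP8, row-major
//   auto WQ = w.to(at::kFloat8_e4m3fn).t();     // [K, N] FP8, column-major
//   auto x_scale = at::ones({M, 1}, x.options().dtype(at::kFloat));
//   auto w_scale = at::ones({1, N}, w.options().dtype(at::kFloat));
//   auto out = at::empty({M, N}, x.options().dtype(at::kBFloat16));
//   at::cuda::detail::f8f8bf16_rowwise(
//       XQ, WQ, x_scale, w_scale, /*bias=*/std::nullopt,
//       /*use_fast_accum=*/true, out);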
void f8f8bf16_rowwise(
    at::Tensor XQ, // FP8
    at::Tensor WQ, // FP8
    at::Tensor x_scale, // FP32
    at::Tensor w_scale, // FP32
    std::optional<at::Tensor> bias, // BF16
    bool use_fast_accum,
    at::Tensor& out) {
#if defined(BUILD_ROWWISE_FP8_KERNEL)
  check_inputs(XQ, WQ, x_scale, w_scale, bias, out);

  dispatch_fp8_rowwise_kernel_on_bias_dtype(
      XQ, WQ, x_scale, w_scale, bias, use_fast_accum, out);
#else // BUILD_ROWWISE_FP8_KERNEL
  TORCH_CHECK(
      false, "Rowwise scaling is not currently supported on your device");
#endif
}

} // namespace at::cuda::detail