# Owner(s): ["module: linear algebra"]

import unittest
from itertools import product
from functools import partial
from typing import Optional
import re

import torch

from torch.quantization._quantized_conversions import (
    pack_int4_to_int8,
    quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    SM53OrLater,
    SM90OrLater,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8
)
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    tol as xtol,
    toleranceOverride,
)

from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_JETSON,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfRocmVersionLessThan,
    TEST_WITH_ROCM,
    skipIfRocm,
    TestCase,
)

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32


@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
class TestMatmulCuda(TestCase):
    def setUp(self):
        super().setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super().tearDown()

    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False):
        #
        # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
        # results from the CUDA invocation of torch.addmm and the CPU invocation
        # (which does not use the CUDA backend).
        #
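        # torch.addmm(input, m1, m2, beta=b) computes b * input + m1 @ m2
        # (with the default alpha=1), so the CPU and CUDA paths below evaluate
        # the same expression on identical data.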
        # Get dims
        n, m, p = (size + 1, size, size + 2)
        # Set reduced-precision reductions for BFloat16/Float16 according to
        # `reduced_precision`; disabling them bypasses some kernels which fail
        # the threshold check
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
        # Make random tensors on CPU (seed set on common_utils.py import)
        # (Not using numpy because it does not support bfloat16)
        make_arg = partial(make_tensor, dtype=dtype, device="cpu")
        m_beta = make_arg(1)
        m_input = make_arg((n, p))
        m_1 = make_arg((n, m))
        m_2 = make_arg((m, p))
        # *(B)FLOAT16 Special Handling*
        # Backend does not tensorize float16 on CPU,
        # and bfloat16 may present accuracy issues,
        # so convert to float32 for these cases
        # (but keep same for other types, e.g. float32 and int*)
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=torch.float32)
            m_input = m_input.to(dtype=torch.float32)
            m_1 = m_1.to(dtype=torch.float32)
            m_2 = m_2.to(dtype=torch.float32)
        # Get CPU result
        res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # *(B)FLOAT16 Special Handling*
        # Convert back to (b)float16
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=dtype)
            m_input = m_input.to(dtype=dtype)
            m_1 = m_1.to(dtype=dtype)
            m_2 = m_2.to(dtype=dtype)
            res_cpu = res_cpu.to(dtype=dtype)
        # Move arg tensors to CUDA
        m_beta = m_beta.to("cuda")
        m_input = m_input.to("cuda")
        m_1 = m_1.to("cuda")
        m_2 = m_2.to("cuda")
        # Get CUDA result
        res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # Move to CPU for comparison
        res_cuda = res_cuda.to("cpu")
        # Compare
        self.assertEqual(res_cpu, res_cuda)
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        device = 'cuda'
        # perturb X, A, or B alignment
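        # Slicing `offset` elements off the front of one flat buffer shifts its
        # data pointer, so that tensor loses the alignment cuBLASLt prefers; the
        # linear result must still match the unfused matmul + bias reference.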
        for idx in range(0, 3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiplying by the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
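        # torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha) computes
        # beta * A + alpha * (M1 @ M2) per batch; with alpha = beta = 1 each
        # batch element should match the `linear` result above (checked by the
        # cross comparison at the end).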
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiplying by the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])


f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"

if torch.version.hip:
    e4m3_type = torch.float8_e4m3fnuz
    e5m2_type = torch.float8_e5m2fnuz
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
    e4m3_type = torch.float8_e4m3fn
    e5m2_type = torch.float8_e5m2
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

# avoid division by zero when calculating scale
EPS = 1e-12

def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """ Converts the amax value of a tensor to the fp8 scale.
    Args:
        amax: The amax value of the tensor.
        float8_dtype: The float8 dtype.
        orig_dtype: The original dtype of the tensor.
    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == e4m3_type:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    # Ensure the scale is representable in float16;
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)

    scale.copy_(res)
    return scale

def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
    if dim is None:
        amax = torch.max(torch.abs(x))
    else:
        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values

    return amax_to_scale(amax, float8_dtype, x.dtype)

def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
    # naive implementation: dq -> op -> q
    x_fp32 = x.to(torch.float) / x_scale
    y_fp32 = y.to(torch.float) / y_scale
    out_fp32 = torch.mm(x_fp32, y_fp32)

    return out_fp32.to(out_dtype)

def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
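    # The fp8 data in this file is produced as `x * scale` with
    # scale = FP8_MAX / amax(x), so dequantization divides by that scale.
    # torch._scaled_mm multiplies each operand by its scale argument, hence the
    # reciprocals passed below (mirroring the division in mm_float8_emulated).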
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()
    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32
        output = torch._scaled_mm(
            a_data,
            b_data,
            scale_a=a_inverse_scale,
            scale_b=b_inverse_scale,
            scale_result=output_scale,
            out_dtype=output_dtype,
        )
        output += bias
        return output
    output = torch._scaled_mm(
        a_data,
        b_data,
        bias=bias,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        scale_result=output_scale,
        out_dtype=output_dtype,
    )
    return output

def mm_float8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,  # output dtype
    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
) -> torch.Tensor:
    return addmm_float8_unwrapped(
        a, a_scale, b, b_scale, output_dtype, output_scale
    )

def to_fp8_saturated(
    x: torch.Tensor,
    fp8_dtype: torch.dtype
):
    if fp8_dtype == e4m3_type:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif fp8_dtype == e5m2_type:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")

    return x.to(fp8_dtype)

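# Illustrative sketch only (defined but never called by the tests below): shows how
# the helpers above combine to quantize a tensor to fp8 and dequantize it back,
# mirroring the setup used in TestFP8MatmulCuda. Assumes the fp8 dtype selected
# above supports CPU casts in this build; on a CUDA device the same code works
# with device="cuda" tensors.
def _example_fp8_roundtrip_sketch():
    x = torch.randn(8, 8, dtype=torch.float32)
    scale = tensor_to_scale(x, e4m3_type)           # scale = FP8_MAX / amax(x)
    x_fp8 = to_fp8_saturated(x * scale, e4m3_type)  # clamp to the fp8 range, then cast
    x_dq = x_fp8.to(torch.float32) / scale          # dequantize
    return x, x_dq                                  # x_dq approximates x
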
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):

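    # "Tautological" multiplication: one operand is an identity matrix, so the
    # fp8 product must match the fp32 reference computed alongside it.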
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = e4m3_type,
                              y_dtype: torch.dtype = e4m3_type,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_basics(self, device) -> None:
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
            self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, out_dtype=e5m2_type)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_change_stride(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.ones((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        out_fp32 = out_fp8.to(torch.float32)
        outb_fp32 = outb_fp8.to(torch.float32)
        difference = torch.abs(out_fp32 - outb_fp32)
        self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
        outb_fp32 = outb_fp8.to(torch.float32)
        self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
        )

    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
                     "This test is only for devices with compute capability < 8.9")
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    @parametrize("use_fast_accum", [True, False])
    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

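        # Row-wise scaling: for x_fp8 (M, K) @ y_fp8 (K, N), scale_a has shape
        # (M, 1) and scale_b has shape (1, N). With all-ones scales every output
        # element should equal K * fill_value**2.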
        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        out_fp8 = torch._scaled_mm(
            x_fp8,
            y_fp8,
            scale_a=x_scales,
            scale_b=y_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=use_fast_accum,
        )
        self.assertEqual(
            out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    def test_float8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

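        # Each case below violates one RowWise-scaling requirement checked by
        # _scaled_mm: scale_a must have shape (M, 1) and scale_b (1, N), both
        # scales must be 2-dimensional and contiguous, and the second operand
        # must be a float8_e4m3fn tensor.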
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((1, 1), device="cuda"),
                scale_b=torch.ones((1, 2), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N + 1), device="cuda"),
                out_dtype=torch.bfloat16,
            )
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M), device="cuda"),
                scale_b=torch.ones((N, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Both scale_a and scale_b must be contiguous for RowWise scaling."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For RowWise scaling the second input is required to be a float8_e4m3fn dtype."),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8.to(torch.float8_e5m2),
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("base_dtype", [torch.bfloat16])
    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()

        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scales, y_fp8, y_scales, output_dtype
        )

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)


@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            if not add_bias and activation != "none":
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            input = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
                )
                if add_bias
                else None
            )

            input_ref = input.reshape(-1, input.shape[-1])

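            # The reference paths below dequantize the int8/int4 weight to
            # input.dtype using the per-output-channel `scale`; the fused op
            # instead consumes the CUTLASS-reordered quantized weight together
            # with the same scale.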
            # First, test plain multiplication.
            weight_ref = weight.T.to(input.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(input.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*input.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )

instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()