# Owner(s): ["module: linear algebra"]

import unittest
from itertools import product
from functools import partial
from typing import Optional
import re

import torch

from torch.quantization._quantized_conversions import (
    pack_int4_to_int8,
    quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    SM53OrLater,
    SM90OrLater,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    tol as xtol,
    toleranceOverride,
)

from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_JETSON,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfRocmVersionLessThan,
    TEST_WITH_ROCM,
    skipIfRocm,
    TestCase,
)

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
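# (A major compute capability of 8 covers Ampere sm_80/sm_86 and Ada sm_89
# parts; the mixed-dtypes CUTLASS tests at the bottom of this file are gated
# on this flag.)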
# Protects against imports accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32


@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
class TestMatmulCuda(TestCase):
    def setUp(self):
        super().setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super().tearDown()
    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False):
        #
        # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
        # results from the CUDA invocation of torch.addmm and the CPU invocation
        # (which does not use the CUDA backend).
        #
        # Get dims
        n, m, p = (size + 1, size, size + 2)
        # Toggle reduced-precision reductions (both bf16 and fp16) per the
        # `reduced_precision` argument; disabling them bypasses some kernels
        # which fail the threshold check
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
        # Make random tensors on CPU (seed set on common_utils.py import)
        # (Not using numpy because it does not support bfloat16)
        make_arg = partial(make_tensor, dtype=dtype, device="cpu")
        m_beta = make_arg(1)
        m_input = make_arg((n, p))
        m_1 = make_arg((n, m))
        m_2 = make_arg((m, p))
        # *(B)FLOAT16 Special Handling*
        # Backend does not tensorize float16 on CPU,
        # and bfloat16 may present accuracy issues,
        # so convert to float32 for these cases
        # (but keep same for other types, e.g. float32 and int*)
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=torch.float32)
            m_input = m_input.to(dtype=torch.float32)
            m_1 = m_1.to(dtype=torch.float32)
            m_2 = m_2.to(dtype=torch.float32)
        # Get CPU result
        res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # *(B)FLOAT16 Special Handling*
        # Convert back to (b)float16
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=dtype)
            m_input = m_input.to(dtype=dtype)
            m_1 = m_1.to(dtype=dtype)
            m_2 = m_2.to(dtype=dtype)
            res_cpu = res_cpu.to(dtype=dtype)
        # Move arg tensors to CUDA
        m_beta = m_beta.to("cuda")
        m_input = m_input.to("cuda")
        m_1 = m_1.to("cuda")
        m_2 = m_2.to("cuda")
        # Get CUDA result
        res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # Move to CPU for comparison
        res_cuda = res_cuda.to("cpu")
        # Compare
        self.assertEqual(res_cpu, res_cuda)
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16
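    # Illustrative sketch (not executed by the test suite): the
    # reduced-precision toggles above are global backend flags, so the same
    # save/set/restore pattern works anywhere; wrapping the restore in
    # try/finally (which the helper above omits) also guarantees cleanup when
    # an assertion fails. `bias`, `a`, `b` are placeholder tensors:
    #
    #   orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
    #   torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
    #   try:
    #       y = torch.addmm(bias, a, b)  # may use reduced-precision reductions
    #   finally:
    #       torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig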
    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, True)
    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        device = 'cuda'
        # Perturb the alignment of X, A, or B by slicing `offset` elements off
        # a flat buffer; this shifts the underlying data pointer and exercises
        # the cuBLAS path's address-alignment handling.
        for idx in range(0, 3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)
    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    # bfloat16 is only included on ROCm or SM53+; the parentheses around the
    # conditional keep float32/float16 unconditional (list addition binds
    # tighter than the ternary).
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # Test multiplying by the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size,) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])
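        # Note: expand() returns zero-stride views along the new batch
        # dimension, so every batch entry aliases the same 2-D storage and no
        # data is copied.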
        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # Test multiplying by the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])


f8_msg = "FP8 is only supported on H100+, sm_89, and MI300+ devices"

if torch.version.hip:
    e4m3_type = torch.float8_e4m3fnuz
    e5m2_type = torch.float8_e5m2fnuz
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
    e4m3_type = torch.float8_e4m3fn
    e5m2_type = torch.float8_e5m2
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
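# For reference, the resulting saturation bounds (from torch.finfo) are
# 448.0 for e4m3fn and 57344.0 for e5m2 on CUDA; the ROCm fnuz variants give
# 240.0 and 57344.0 respectively.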
# avoid division by zero when calculating scale
EPS = 1e-12

def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """Converts the amax value of a tensor to the fp8 scale.
    Args:
        amax: The amax value of the tensor.
        float8_dtype: the float8 dtype.
        orig_dtype: The original dtype of the tensor.
    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == e4m3_type:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
    # Ensure the scale is representable in float16; this helps when amax is
    # small. We assume this is not a concern for float32/bfloat16.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)

    scale.copy_(res)
    return scale

def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
    if dim is None:
        amax = torch.max(torch.abs(x))
    else:
        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values

    return amax_to_scale(amax, float8_dtype, x.dtype)
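# Illustrative sketch (not executed here): computing a per-tensor scale and
# quantizing with it, assuming a CUDA device with fp8 support:
#
#   x = torch.randn(16, 16, device="cuda")
#   x_scale = tensor_to_scale(x, e4m3_type).float()   # scalar (tensor-wise) scale
#   x_fp8 = to_fp8_saturated(x * x_scale, e4m3_type)  # defined below
#
# Passing dim=1 instead yields one scale per row, i.e. row-wise scaling.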
def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
    # naive implementation: dq -> op -> q
    x_fp32 = x.to(torch.float) / x_scale
    y_fp32 = y.to(torch.float) / y_scale
    out_fp32 = torch.mm(x_fp32, y_fp32)

    return out_fp32.to(out_dtype)
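# The emulation treats value / scale as the dequantized number, i.e. it
# computes (x / x_scale) @ (y / y_scale) in fp32 and casts to out_dtype; the
# fp8 tests below use it as the reference for torch._scaled_mm, e.g.:
#
#   out_ref = mm_float8_emulated(x_fp8, x_scale, y_fp8, y_scale, torch.bfloat16)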
def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()
    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32
        output = torch._scaled_mm(
            a_data,
            b_data,
            scale_a=a_inverse_scale,
            scale_b=b_inverse_scale,
            scale_result=output_scale,
            out_dtype=output_dtype,
        )
        output += bias
        return output
    output = torch._scaled_mm(
        a_data,
        b_data,
        bias=bias,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        scale_result=output_scale,
        out_dtype=output_dtype,
    )
    return output
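# Note on conventions: the helpers in this file define "scale" in the
# quantization direction (fp8_value ~ original * scale), while
# torch._scaled_mm expects dequantization multipliers, hence the reciprocal()
# calls above. A minimal call, assuming tensor-wise scales:
#
#   out = addmm_float8_unwrapped(x_fp8, x_scale, y_fp8, y_scale,
#                                output_dtype=torch.bfloat16, output_scale=None)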
def mm_float8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,  # output dtype
    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
) -> torch.Tensor:
    return addmm_float8_unwrapped(
        a, a_scale, b, b_scale, output_dtype, output_scale
    )

def to_fp8_saturated(
    x: torch.Tensor,
    fp8_dtype: torch.dtype
):
    if fp8_dtype == e4m3_type:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif fp8_dtype == e5m2_type:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")

    return x.to(fp8_dtype)
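# Saturated vs. plain casts differ only for out-of-range inputs: the clamp
# above pins values to the fp8 type's finite maximum before converting, so
# the cast itself can never overflow. Sketch:
#
#   t = torch.tensor([1000.0], device="cuda")
#   to_fp8_saturated(t, e4m3_type)  # clamped to E4M3_MAX_POS first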
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = e4m3_type,
                              y_dtype: torch.dtype = e4m3_type,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
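        # y is a (transposed) identity, so the expected product is exactly x;
        # hence "tautological". The .t() also gives the column-major layout
        # that torch._scaled_mm requires for its second operand.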
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_basics(self, device) -> None:
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
            self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        # According to https://docs.nvidia.com/cuda/cublas/#id99, e5m2 x e5m2
        # (8F_E5M2) matmul is unsupported
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, out_dtype=e5m2_type)
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
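        # (Exact math gives 16 * (0.5 * 1.5) * (0.5 * 0.66) = 3.96; stored in
        # the fp8 output type this rounds to 4.0, the nearest representable
        # value, which is why the comparison above is exact.)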
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_change_stride(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.ones((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        out_fp32 = out_fp8.to(torch.float32)
        outb_fp32 = outb_fp8.to(torch.float32)
        difference = torch.abs(out_fp32 - outb_fp32)
        self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
        outb_fp32 = outb_fp8.to(torch.float32)
        self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
        )

    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
                     "This test is only for devices with compute capability < 8.9")
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
        )
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8, out_fp8_s)
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    @parametrize("use_fast_accum", [True, False])
    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        out_fp8 = torch._scaled_mm(
            x_fp8,
            y_fp8,
            scale_a=x_scales,
            scale_b=y_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=use_fast_accum,
        )
        self.assertEqual(
            out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
        )
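        # (With fill_value = 0.5 and K = 512, every output element is
        # 512 * 0.25 = 128.0, which bfloat16 represents exactly.)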
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    def test_float8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((1, 1), device="cuda"),
                scale_b=torch.ones((1, 2), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N + 1), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M), device="cuda"),
                scale_b=torch.ones((N, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Both scale_a and scale_b must be contiguous for RowWise scaling."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For RowWise scaling the second input is required to be a float8_e4m3fn dtype."),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8.to(torch.float8_e5m2),
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )
666*da0073e9SAndroid Build Coastguard Worker
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("base_dtype", [torch.bfloat16])
    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()

        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)
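        # The tensor_to_scale helper derives amax-based scales per row of x
        # (dim=1) and per column of y (dim=0); multiplying by the scales and
        # saturating to e4m3 yields the fp8 operands fed to both paths below.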

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scales, y_fp8, y_scales, output_dtype
        )
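        # Roughly, the emulated reference upcasts both fp8 operands to
        # float32, divides out the scales, runs a plain torch.mm, and casts
        # the result to output_dtype.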

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)


@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            # Activation epilogues are only exercised together with a bias;
            # skip those combinations, but still run the plain no-bias case.
            if not add_bias and activation is not None:
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            input = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
                )
                if add_bias
                else None
            )

            input_ref = input.reshape(-1, input.shape[-1])

            # First, test plain multiplication.
            weight_ref = weight.T.to(input.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
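            # For the quint4x2 case, pack_int4_to_int8 packs each pair of
            # int4 values (generated here as int8 in [-2, 2)) into a single
            # int8 byte, the packed layout the reorder helper below consumes.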
            output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(input.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*input.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
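        # Each triple is (m, n, k); n and k are multiples of 64, presumably to
        # satisfy the CUTLASS kernel's alignment requirements.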
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
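        # bfloat16 keeps only 8 significand bits (vs. 11 for float16), so the
        # relative tolerance is loosened by an order of magnitude.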
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )

instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")
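
# A typical invocation, assuming a CUDA build of PyTorch (unittest-style -k
# filtering narrows the run to a single suite):
#
#   python test/test_matmul_cuda.py -k TestMixedDtypesLinearCuda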

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()