# Owner(s): ["module: linear algebra"]

import unittest
from itertools import product
from functools import partial
from typing import Optional
import re

import torch

from torch.quantization._quantized_conversions import (
    pack_int4_to_int8,
    quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    SM53OrLater,
    SM90OrLater,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    tol as xtol,
    toleranceOverride,
)

from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_JETSON,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfRocmVersionLessThan,
    TEST_WITH_ROCM,
    skipIfRocm,
    TestCase,
)

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32


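# Note: setUp/tearDown below toggle torch.backends.cuda.matmul.allow_tf32 so that
# float32 matmuls in these tests run in true FP32 rather than TF32, which keeps the
# CUDA results closer to the CPU reference they are compared against.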
@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
class TestMatmulCuda(TestCase):
    def setUp(self):
        super(self.__class__, self).setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super(self.__class__, self).tearDown()

    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False):
        #
        # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
        # results from the CUDA invocation of torch.addmm and the CPU invocation
        # (which does not use the CUDA backend).
        #
        # Get dims
        n, m, p = (size + 1, size, size + 2)
        # Disable reduced precision reductions in BFloat16 to bypass some kernels
        # which fail the threshold check
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
        # Make random tensors on CPU (seed set on common_utils.py import)
        # (Not using numpy because it does not support bfloat16)
        make_arg = partial(make_tensor, dtype=dtype, device="cpu")
        m_beta = make_arg(1)
        m_input = make_arg((n, p))
        m_1 = make_arg((n, m))
        m_2 = make_arg((m, p))
        # *(B)FLOAT16 Special Handling*
        # Backend does not tensorize float16 on CPU,
        # and bfloat16 may present accuracy issues,
        # so convert to float32 for these cases
        # (but keep same for other types, e.g. float32 and int*)
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=torch.float32)
            m_input = m_input.to(dtype=torch.float32)
            m_1 = m_1.to(dtype=torch.float32)
            m_2 = m_2.to(dtype=torch.float32)
        # Get CPU result
        res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # *(B)FLOAT16 Special Handling*
        # Convert back to (b)float16
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=dtype)
            m_input = m_input.to(dtype=dtype)
            m_1 = m_1.to(dtype=dtype)
            m_2 = m_2.to(dtype=dtype)
            res_cpu = res_cpu.to(dtype=dtype)
        # Move arg tensors to CUDA
        m_beta = m_beta.to("cuda")
        m_input = m_input.to("cuda")
        m_1 = m_1.to("cuda")
        m_2 = m_2.to("cuda")
        # Get CUDA result
        res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # Move to CPU for comparison
        res_cuda = res_cuda.to("cpu")
        # Compare
        self.assertEqual(res_cpu, res_cuda)
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False)

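    # The reduced-precision variant below re-runs the same CUDA-vs-CPU check with
    # torch.backends.cuda.matmul.allow_{fp16,bf16}_reduced_precision_reduction enabled,
    # which lets cuBLAS keep intermediate reductions in the low-precision dtype; that
    # is why the decorators that follow use much looser tolerances.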
    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        device = 'cuda'
        # perturb X, A, or B alignment
        for idx in range(0, 3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])


f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"

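# ROCm's fp8 support is based on the "fnuz" variants, which have no negative zero and
# a different exponent bias than the OCP float8_e4m3fn/float8_e5m2 formats (for
# example, torch.finfo(torch.float8_e4m3fnuz).max is 240 vs. 448 for float8_e4m3fn),
# so the dtype aliases and max values are selected per platform below.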
if torch.version.hip:
    e4m3_type = torch.float8_e4m3fnuz
    e5m2_type = torch.float8_e5m2fnuz
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
    e4m3_type = torch.float8_e4m3fn
    e5m2_type = torch.float8_e5m2
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

# avoid division by zero when calculating scale
EPS = 1e-12

def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """ Converts the amax value of a tensor to the fp8 scale.
    Args:
        amax: The amax value of the tensor.
        float8_dtype: The float8 dtype.
        orig_dtype: The original dtype of the tensor.
    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == e4m3_type:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    # Ensure the scale is representable in float16,
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)

    scale.copy_(res)
    return scale

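# Worked example (illustrative): with float8_e4m3fn, E4M3_MAX_POS is 448, so a tensor
# whose amax is 0.05 gets a scale of roughly 448 / 0.05 = 8960; multiplying by the
# scale before the fp8 cast spreads values across the representable range, and the
# emulated matmul below divides the scale back out.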
def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
    if dim is None:
        amax = torch.max(torch.abs(x))
    else:
        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values

    return amax_to_scale(amax, float8_dtype, x.dtype)

def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
    # naive implementation: dq -> op -> q
    x_fp32 = x.to(torch.float) / x_scale
    y_fp32 = y.to(torch.float) / y_scale
    out_fp32 = torch.mm(x_fp32, y_fp32)

    return out_fp32.to(out_dtype)

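# mm_float8_emulated is the reference path used by the tests below: it dequantizes
# each operand by dividing out its scale, runs the matmul in fp32, and casts to the
# requested output dtype; torch._scaled_mm (wrapped by mm_float8 below) is expected
# to match it within the tolerances set in the tests.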
def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()
    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32
        output = torch._scaled_mm(
            a_data,
            b_data,
            scale_a=a_inverse_scale,
            scale_b=b_inverse_scale,
            scale_result=output_scale,
            out_dtype=output_dtype,
        )
        output += bias
        return output
    output = torch._scaled_mm(
        a_data,
        b_data,
        bias=bias,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        scale_result=output_scale,
        out_dtype=output_dtype,
    )
    return output

def mm_float8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,  # output dtype
    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
) -> torch.Tensor:
    return addmm_float8_unwrapped(
        a, a_scale, b, b_scale, output_dtype, output_scale
    )

def to_fp8_saturated(
    x: torch.Tensor,
    fp8_dtype: torch.dtype
):
    if fp8_dtype == e4m3_type:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif fp8_dtype == e5m2_type:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")

    return x.to(fp8_dtype)

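# to_fp8_saturated clamps to +/-MAX_POS before the cast so that values outside the
# narrow fp8 range cannot overflow during conversion (float8_e4m3fn has no inf
# encoding, only NaN).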
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = e4m3_type,
                              y_dtype: torch.dtype = e4m3_type,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_basics(self, device) -> None:
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
            self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, out_dtype=e5m2_type)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

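    # The next two tests quantize random fp16/bf16/fp32 inputs, run torch._scaled_mm,
    # and compare against the dequantize -> fp32 mm -> cast reference in
    # mm_float8_emulated, with looser tolerances when the base dtype is fp16/bf16.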
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_change_stride(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.ones((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
        out_fp32 = out_fp8.to(torch.float32)
        outb_fp32 = outb_fp8.to(torch.float32)
        difference = torch.abs(out_fp32 - outb_fp32)
        self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
        outb_fp32 = outb_fp8.to(torch.float32)
        self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))

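    # The relu edge case above multiplies an all-zero matrix and adds a bias of -3.0;
    # asserting that the output stays at -3.0 verifies that no ReLU epilogue is
    # silently applied to the biased fp8 result.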
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
        )

    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
                     "This test is only for devices with compute capability < 8.9")
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    @parametrize("use_fast_accum", [True, False])
    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        out_fp8 = torch._scaled_mm(
            x_fp8,
            y_fp8,
            scale_a=x_scales,
            scale_b=y_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=use_fast_accum,
        )
        self.assertEqual(
            out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    def test_float8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((1, 1), device="cuda"),
                scale_b=torch.ones((1, 2), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N + 1), device="cuda"),
                out_dtype=torch.bfloat16,
            )
        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M), device="cuda"),
                scale_b=torch.ones((N, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Both scale_a and scale_b must be contiguous for RowWise scaling."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For RowWise scaling the second input is required to be a float8_e4m3fn dtype."),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8.to(torch.float8_e5m2),
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

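    # Row-wise scaling differs from the per-tensor path above: scale_a holds one scale
    # per row of the first operand (shape (M, 1)) and scale_b one scale per column of
    # the second (shape (1, N)), and both scale tensors must be contiguous.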
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("base_dtype", [torch.bfloat16])
    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()

        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scales, y_fp8, y_scales, output_dtype
        )

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)


@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            if not add_bias and activation != "none":
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            input = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
                )
                if add_bias
                else None
            )

            input_ref = input.reshape(-1, input.shape[-1])

            # First, test plain multiplication.
            weight_ref = weight.T.to(input.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(input.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*input.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

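        # run_test first checks the raw mixed-dtype matmul (int8/int4 weight, fp16/bf16
        # activation, per-channel scale), then the full linear path with optional bias
        # and relu/silu epilogues; for quint4x2 the int4 values are packed two per int8
        # byte by pack_int4_to_int8 before the CUTLASS-specific weight reordering.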
        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )

instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()