# Owner(s): ["module: linear algebra"]

import unittest
from itertools import product
from functools import partial
from typing import Optional
import re

import torch

from torch.quantization._quantized_conversions import (
    pack_int4_to_int8,
    quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import (
    SM53OrLater,
    SM90OrLater,
    _get_torch_cuda_version,
    PLATFORM_SUPPORTS_FP8,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    tol as xtol,
    toleranceOverride,
)

from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_JETSON,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfRocmVersionLessThan,
    TEST_WITH_ROCM,
    skipIfRocm,
    TestCase,
)

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32


@unittest.skipIf(IS_ARM64, "Issue with numpy version on arm")
class TestMatmulCuda(TestCase):
    def setUp(self):
        super(self.__class__, self).setUp()
        torch.backends.cuda.matmul.allow_tf32 = False

    def tearDown(self):
        torch.backends.cuda.matmul.allow_tf32 = True
        super(self.__class__, self).tearDown()

    def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False):
        #
        # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between
        # results from the CUDA invocation of torch.addmm and the CPU invocation
        # (which does not use CUDA backend).
        #
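        # Both invocations evaluate the same expression (torch.addmm semantics):
        #   out = beta * m_input + m_1 @ m_2   (alpha left at its default of 1)
        #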
        # Get dims
        n, m, p = (size + 1, size, size + 2)
        # Disable reduced precision reductions in BFloat16 to bypass some kernels
        # which fail the threshold check
        orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
        orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision
        # Make random tensors on CPU (seed set on common_utils.py import)
        # (Not using numpy because it does not support bfloat16)
        make_arg = partial(make_tensor, dtype=dtype, device="cpu")
        m_beta = make_arg(1)
        m_input = make_arg((n, p))
        m_1 = make_arg((n, m))
        m_2 = make_arg((m, p))
        # *(B)FLOAT16 Special Handling*
        # Backend does not tensorize float16 on CPU,
        # and bfloat16 may present accuracy issues,
        # so convert to float32 for these cases
        # (but keep the same for other types, e.g. float32 and int*)
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=torch.float32)
            m_input = m_input.to(dtype=torch.float32)
            m_1 = m_1.to(dtype=torch.float32)
            m_2 = m_2.to(dtype=torch.float32)
        # Get CPU result
        res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # *(B)FLOAT16 Special Handling*
        # Convert back to (b)float16
        if dtype == torch.float16 or dtype == torch.bfloat16:
            m_beta = m_beta.to(dtype=dtype)
            m_input = m_input.to(dtype=dtype)
            m_1 = m_1.to(dtype=dtype)
            m_2 = m_2.to(dtype=dtype)
            res_cpu = res_cpu.to(dtype=dtype)
        # Move arg tensors to CUDA
        m_beta = m_beta.to("cuda")
        m_input = m_input.to("cuda")
        m_1 = m_1.to("cuda")
        m_2 = m_2.to("cuda")
        # Get CUDA result
        res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item())
        # Move to CPU for comparison
        res_cuda = res_cuda.to("cpu")
        # Compare
        self.assertEqual(res_cpu, res_cuda)
        torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16
        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, True)

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        device = 'cuda'
        # perturb X, A, or B alignment
        for idx in range(0, 3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        cpu_dtype = dtype
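        # Compute the CPU reference in float32 for half/bfloat16 inputs so that
        # reference-side rounding error does not dominate the comparison.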
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiply the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])


f8_msg = "FP8 is only supported on H100+ and sm_89 and MI300+ devices"

if torch.version.hip:
    e4m3_type = torch.float8_e4m3fnuz
    e5m2_type = torch.float8_e5m2fnuz
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
    e4m3_type = torch.float8_e4m3fn
    e5m2_type = torch.float8_e5m2
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

# avoid division by zero when calculating scale
EPS = 1e-12

def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """ Converts the amax value of a tensor to the fp8 scale.
    Args:
        amax: The amax value of the tensor.
        float8_dtype: the float8 dtype.
        orig_dtype: The original dtype of the tensor.
    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == e4m3_type:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    # Ensure the scale is representable in float16;
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)

    scale.copy_(res)
    return scale
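
# For example (assuming the non-fnuz e4m3 type, whose largest representable
# value is 448.0): a tensor with amax == 2.0 gets scale = 448.0 / 2.0 = 224.0,
# and the tests multiply by that scale before casting to fp8 (see
# to_fp8_saturated below).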

def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype, dim=None):
    if dim is None:
        amax = torch.max(torch.abs(x))
    else:
        amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values

    return amax_to_scale(amax, float8_dtype, x.dtype)

def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype) -> torch.Tensor:
    # naive implementation: dq -> op -> q
    x_fp32 = x.to(torch.float) / x_scale
    y_fp32 = y.to(torch.float) / y_scale
    out_fp32 = torch.mm(x_fp32, y_fp32)

    return out_fp32.to(out_dtype)
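
# Note: torch._scaled_mm takes dequantization scales (the factor that maps the
# fp8 data back to the original range), while the helpers above produce
# quantization scales (MAX / amax), hence the .reciprocal() calls below.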
def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()
    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32
        output = torch._scaled_mm(
            a_data,
            b_data,
            scale_a=a_inverse_scale,
            scale_b=b_inverse_scale,
            scale_result=output_scale,
            out_dtype=output_dtype,
        )
        output += bias
        return output
    output = torch._scaled_mm(
        a_data,
        b_data,
        bias=bias,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        scale_result=output_scale,
        out_dtype=output_dtype,
    )
    return output

def mm_float8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,  # output dtype
    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
) -> torch.Tensor:
    return addmm_float8_unwrapped(
        a, a_scale, b, b_scale, output_dtype, output_scale
    )

def to_fp8_saturated(
    x: torch.Tensor,
    fp8_dtype: torch.dtype
):
    if fp8_dtype == e4m3_type:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif fp8_dtype == e5m2_type:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")

    return x.to(fp8_dtype)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = e4m3_type,
                              y_dtype: torch.dtype = e4m3_type,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x_fp8, y_fp8, scale_a, scale_b, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_basics(self, device) -> None:
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
            self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, out_dtype=e5m2_type)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_change_stride(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.empty_strided((16, 16), (16, 1), device="cuda", dtype=base_dtype)
        y = torch.empty_strided((16, 32), (1, 64), device="cuda", dtype=base_dtype)
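        # y is laid out column-major with a padded leading dimension (stride 64
        # for only 16 rows), exercising the non-default-stride path of _scaled_mm.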

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x * x_scale, input_dtype)
        y_fp8 = to_fp8_saturated(y * y_scale, input_dtype)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_scaled_mm = out_scaled_mm / tensor_to_scale(out_scaled_mm, input_dtype)

            out_emulated = out_emulated.to(compare_type)
            out_emulated = out_emulated / tensor_to_scale(out_emulated, input_dtype)

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 3e-3, 3e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.ones((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        outb_fp8 = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have amax op
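        # Every element of x @ y is 48 * 1.0 * 0.25 = 12.0, and 12.0 and 16.0 are
        # both exactly representable in fp8, so the two results should differ
        # elementwise by exactly the bias value.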
        out_fp32 = out_fp8.to(torch.float32)
        outb_fp32 = outb_fp8.to(torch.float32)
        difference = torch.abs(out_fp32 - outb_fp32)
        self.assertEqual(difference, torch.tensor(4.0, device=device).expand_as(out_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        _ = torch._scaled_mm(x, y, scale_a, scale_b, bias=input_bias)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        outb_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, bias=bias)
        outb_fp32 = outb_fp8.to(torch.float32)
        self.assertEqual(outb_fp32, torch.tensor(-3.0, device=device).expand_as(outb_fp32))

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, bias=bias, out_dtype=torch.float32),
        )

    @unittest.skipIf(PLATFORM_SUPPORTS_FP8,
                     "This test is only for devices with compute capability < 8.9")
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        scale_a = torch.tensor(1.0, device=device)
        scale_b = torch.tensor(1.0, device=device)
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.float32),
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8 = torch._scaled_mm(x, y, scale_a, scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    @parametrize("use_fast_accum", [True, False])
    def test_float8_rowwise_scaling_sanity(self, device, use_fast_accum: bool) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)
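
        # Row-wise scaling: one scale per row of the left operand (shape (M, 1))
        # and one scale per column of the right operand (shape (1, N)).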
        x_scales = torch.ones((x.shape[0], 1), device=device, dtype=torch.float32)
        y_scales = torch.ones((1, y.shape[0]), device=device, dtype=torch.float32)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        out_fp8 = torch._scaled_mm(
            x_fp8,
            y_fp8,
            scale_a=x_scales,
            scale_b=y_scales,
            out_dtype=torch.bfloat16,
            use_fast_accum=use_fast_accum,
        )
        self.assertEqual(
            out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @skipIfRocm()
    def test_float8_error_messages(self, device) -> None:
        M, K, N = (1024, 512, 2048)
        fill_value = 0.5
        x = torch.full((M, K), fill_value, device=device)
        y = torch.full((N, K), fill_value, device=device)

        x_fp8 = x.to(torch.float8_e4m3fn)
        y_fp8 = y.to(torch.float8_e4m3fn).t()

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1, 1) and scale_b.size()=(1, 2)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((1, 1), device="cuda"),
                scale_b=torch.ones((1, 2), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                " For RowWise scaling, scale_a should be (1024, 1) and scale_b "
                "should be (1, 2048). Got scale_a.size()=(1024, 1) and scale_b.size()=(1, 2049)"
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N + 1), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For non-TensorWise scaling, scale tensors must be 2-dimensional"),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M), device="cuda"),
                scale_b=torch.ones((N, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape(
                "Both scale_a and scale_b must be contiguous for RowWise scaling."
            ),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8,
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N * 2), device="cuda")[:, ::2],
                out_dtype=torch.bfloat16,
            )

        with self.assertRaisesRegex(
            RuntimeError,
            re.escape("For RowWise scaling the second input is required to be a float8_e4m3fn dtype."),
        ):
            torch._scaled_mm(
                x_fp8,
                y_fp8.to(torch.float8_e5m2),
                scale_a=torch.ones((M, 1), device="cuda"),
                scale_b=torch.ones((1, N), device="cuda"),
                out_dtype=torch.bfloat16,
            )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
    @unittest.skipIf(not SM90OrLater, "rowwise implementation is currently sm90 specific")
    @skipIfRocm()
    @parametrize("base_dtype", [torch.bfloat16])
    def test_scaled_mm_vs_emulated_row_wise(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scales = tensor_to_scale(x, input_dtype, dim=1).float()
        y_scales = tensor_to_scale(y, input_dtype, dim=0).float()

        x_fp8 = to_fp8_saturated(x * x_scales, e4m3_type)
        y_fp8 = to_fp8_saturated(y * y_scales, e4m3_type)

        # Calculate actual F8 mm
        out_scaled_mm = mm_float8(
            x_fp8, y_fp8, a_scale=x_scales, b_scale=y_scales, output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated = mm_float8_emulated(
            x_fp8, x_scales, y_fp8, y_scales, output_dtype
        )

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)


@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            if not add_bias and activation != "none":
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            input = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
                )
                if add_bias
                else None
            )

            input_ref = input.reshape(-1, input.shape[-1])

            # First, test plain multiplication.
            weight_ref = weight.T.to(input.dtype) * scale.view(1, n)
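            # For the int4 case, pack_int4_to_int8 packs pairs of 4-bit values into
            # single int8 elements; the reorder helper then permutes the packed
            # weights into the layout expected by the CUTLASS mixed-dtype kernel.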
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(input.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*input.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )

instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()