#
# Copyright (c) 2024 Apple Inc. All rights reserved.
# Provided subject to the LICENSE file in the top level directory.
#

import inspect

import unittest

from typing import Tuple

import torch
from executorch.backends.apple.mps.test.test_mps_utils import TestMPS


class TestLinear(TestMPS):
    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_fp16_linear(self):
        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: torch.nn.Linear(
                        in_size, out_size, bias=use_bias  # noqa
                    ),
                    num_batch_dims=num_batch_dims,
                    uses_bias=use_bias,
                    dtype=torch.float16,
                    atol=5e-2,
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_fp32_linear(self):
        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: torch.nn.Linear(
                        in_size, out_size, bias=use_bias  # noqa
                    ),
                    uses_bias=use_bias,
                    num_batch_dims=num_batch_dims,
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_qc8_linear(self):
        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: torch.nn.Linear(
                        in_size, out_size, bias=use_bias  # noqa
                    ),
                    uses_bias=use_bias,
                    quant_type="per_channel",
                    num_batch_dims=num_batch_dims,
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_fp32_addmm(self):
        """
        Note that the ConvertToLinear pass requires the weight matrix to be transposed.
        """

        class AddMMModule(torch.nn.Module):
            def __init__(self, in_size, out_size):
                super().__init__()
                self.mat = torch.nn.Parameter(torch.randn(in_size, out_size))
                self.bias = torch.nn.Parameter(torch.randn(1, out_size))

            def forward(self, x):
                return torch.addmm(self.bias, x, self.mat)

        self._test_linear(
            lambda in_size, out_size: AddMMModule(in_size, out_size),
            uses_bias=True,
        )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_fp32_linear_fused_relu(self):
        class LinearReluModule(torch.nn.Module):
            def __init__(self, in_size, out_size, use_bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)

            def forward(self, x):
                return torch.nn.functional.relu(self.linear(x))

        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: LinearReluModule(
                        in_size,
                        out_size,
                        use_bias,  # noqa
                    ),
                    uses_bias=use_bias,
                    num_batch_dims=num_batch_dims,
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_qs8_linear_fused_relu(self):
        class LinearReluModule(torch.nn.Module):
            def __init__(self, in_size, out_size, use_bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)

            def forward(self, x):
                return torch.nn.functional.relu(self.linear(x))

        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: LinearReluModule(
                        in_size,
                        out_size,
                        use_bias,  # noqa
                    ),
                    num_batch_dims=num_batch_dims,
                    uses_bias=use_bias,
                    quant_type="per_tensor",
                )

    @unittest.skip("Dynamic shapes not supported in MPS backend")
    def test_qs8_linear(self):
        for use_bias in (True, False):
            for num_batch_dims in range(1, 3):
                self._test_linear(
                    lambda in_size, out_size: torch.nn.Linear(
                        in_size, out_size, bias=use_bias  # noqa
                    ),
                    uses_bias=use_bias,
                    num_batch_dims=num_batch_dims,
                    quant_type="per_tensor",
                )

    @unittest.skip(
        "quantized_decomposed_dequantize_per_channel_default is not supported by MPS delegate"
    )
    def test_qd8_fp32_per_token_weight_per_channel_int8(self):
        self._run_manual_dqlinear_tests(8, torch.float)

    @unittest.skip(
        "quantized_decomposed_dequantize_per_channel_default is not supported by MPS delegate"
    )
    def test_qd8_fp32_per_token_weight_per_channel_int4(self):
        self._run_manual_dqlinear_tests(4, torch.float)

    def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
        M_sizes = [1]
        K_sizes = [64]
        bl_sizes = [64]
        N_sizes = [32]

        for use_bias in [True, False]:
            for i, _ in enumerate(M_sizes):
                M = int(M_sizes[i])
                K = int(K_sizes[i])
                N = int(N_sizes[i])
                bl = int(bl_sizes[i])
                mod = self.ManualDQLinear(
                    input_channels=K,
                    output_channels=N,
                    weight_n_bit=4,
                    dtype=torch.float,
                    group_size=bl,
                    force_groupwise_quant=True,
                    use_bias=use_bias,
                )

                inputs = (torch.randn(1, M, K),)
                self._test_manual_dq_linear(
                    mod,
                    inputs,
                    weight_groupwise=True,
                    use_bias=use_bias,
                )

    @unittest.skip("Need to fix the dq_per_channel_group output dtype")
    def _test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
        M_sizes = [1, 2, 17, 31]
        K_sizes = [8, 32, 64, 128]
        bl_sizes = [8, 16, 16, 32]
        N_sizes = [2, 17, 92, 128]

        for use_bias in [True, False]:
            for i, _ in enumerate(M_sizes):
                M = int(M_sizes[i])
                K = int(K_sizes[i])
                N = int(N_sizes[i])
                bl = int(bl_sizes[i])
                mod = self.ManualDQLinear(
                    input_channels=K,
                    output_channels=N,
                    weight_n_bit=4,
                    dtype=torch.float16,
                    group_size=bl,
                    force_groupwise_quant=True,
                    use_bias=use_bias,
                )

                inputs = (torch.randn(1, M, K, dtype=torch.float16),)
                self._test_manual_dq_linear(
                    mod,
                    inputs,
                    weight_groupwise=True,
                    use_bias=use_bias,
                    atol=0.1,
                    rtol=0.1,
                )

    def _test_linear(
        self,
        make_module,
        uses_bias,
        num_batch_dims=1,
        quant_type=None,
        dtype: torch.dtype = torch.float,
        atol=1e-03,
    ):
        in_sizes = [3, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]

        for i, _ in enumerate(in_sizes):
            in_size = int(in_sizes[i])
            input_size = int(input_sizes[i])
            output_size = int(output_sizes[i])
            input_shape = [in_size] * num_batch_dims + [input_size]
            print(f"Testing input_shape {input_shape} with {output_size} out_channels")

            module = make_module(input_size, output_size).eval().to(dtype)
            inputs = (torch.randn(input_shape).to(dtype),)
            dynamic_shape = {}
            for i in range(num_batch_dims):
                dynamic_shape[i] = torch.export.Dim(f"batch{i}", min=2, max=in_size)

            dynamic_shape = (dynamic_shape,)
            print(dynamic_shape)
            self.lower_and_test_without_partitioner(
                module,
                inputs,
                func_name=inspect.stack()[0].function[5:],
                dynamic_shapes=dynamic_shape,
                atol=atol,
                rtol=1e-03,
            )

    class ManualDQLinear(torch.nn.Module):
        def __init__(
            self,
            input_channels: int = 4,
            output_channels: int = 4,
            dtype: torch.dtype = torch.float,
            weight_n_bit: int = 4,
            group_size: int = 0,
            force_groupwise_quant: bool = False,
            use_bias: bool = False,
        ):
            super().__init__()
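
            # Emulated dynamically-quantized linear: the weight is quantized
            # once at construction time (per channel, or per channel group when
            # a group size is given or force_groupwise_quant is set), and the
            # activation is fake-quantized per token on every forward pass.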
            self.ic = input_channels
            self.oc = output_channels

            assert dtype in [torch.float, torch.half], "Unsupported op dtype"
            self.op_dtype = dtype

            self.group_size = self.ic if group_size == 0 else group_size
            self.num_groups = 1
            if self.group_size != self.ic:
                assert self.ic % self.group_size == 0
                assert self.group_size % 8 == 0  # TODO make this 16
                self.num_groups = self.ic // self.group_size

            assert weight_n_bit in [4, 8], "Unsupported weight_n_bit"
            self.w_n_bit = weight_n_bit
            self.w_quant_min, self.w_quant_max = self.get_min_max(self.w_n_bit)

            self.w = torch.nn.Parameter(
                torch.randn(self.oc, self.ic), requires_grad=False
            )
            self.w_q = torch.nn.Parameter(
                torch.zeros(self.oc, self.ic), requires_grad=False
            )
            # Quantize the weights as per folded setup
            if self.group_size != self.ic or force_groupwise_quant:
                self.w_scales = torch.nn.Parameter(
                    torch.zeros(self.oc, self.num_groups), requires_grad=False
                )
                self.w_zero_points = torch.nn.Parameter(
                    torch.zeros(self.oc, self.num_groups), requires_grad=False
                )
                self.quant_weight_per_channel_group()
            else:  # per_channel quantization
                self.w_scales = torch.nn.Parameter(
                    torch.zeros(self.oc), requires_grad=False
                )
                self.w_zero_points = torch.nn.Parameter(
                    torch.zeros(self.oc), requires_grad=False
                )
                self.quant_weight_per_channel()

            self.bias = (
                torch.nn.Parameter(
                    torch.randn(self.oc).to(self.op_dtype), requires_grad=False
                )
                if use_bias
                else None
            )

        def get_min_max(self, n_bit: int = 4):
            max_int = 2 ** (n_bit - 1) - 1
            min_int = -(2 ** (n_bit - 1))
            return min_int, max_int

        def get_channel_qparams_symmetric(
            self,
            w: torch.Tensor,
            n_bit: int = 4,
            precision: torch.dtype = torch.float32,
        ):
            assert w.dim() == 2

            to_quant = w.to(precision)
            assert torch.isnan(to_quant).sum() == 0

            max_val = to_quant.amax(dim=1, keepdim=True)
            min_val = to_quant.amin(dim=1, keepdim=True)
            min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
            max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

            min_int, max_int = self.get_min_max(n_bit)

            max_val_abs = torch.max(-min_val_neg, max_val_pos)
            scales = max_val_abs / (float(max_int - min_int) / 2)
            scales = torch.max(
                scales, torch.full_like(scales, torch.finfo(torch.float32).eps)
            )
            zeros = torch.full_like(scales, 0)
            return scales.to(precision).reshape(w.shape[0]), zeros.to(
                precision
            ).reshape(w.shape[0])

        # Note: not using torchao.quantization.quant_primitives because it will
        # run into op registration issues.
        def get_group_qparams_symmetric(
            self, w, n_bit=4, groupsize=128, precision=torch.float32
        ):
            # needed for GPTQ with padding
            if groupsize > w.shape[-1]:
                groupsize = w.shape[-1]
            assert groupsize > 1
            assert w.shape[-1] % groupsize == 0
            assert w.dim() == 2

            to_quant = w.reshape(-1, groupsize)
            assert torch.isnan(to_quant).sum() == 0

            max_val = to_quant.amax(dim=1, keepdim=True)
            min_val = to_quant.amin(dim=1, keepdim=True)
            min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
            max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

            max_val_abs = torch.max(-min_val_neg, max_val_pos)
            max_int = 2 ** (n_bit - 1) - 1
            min_int = -(2 ** (n_bit - 1))

            scales = max_val_abs / (float(max_int - min_int) / 2)
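            # Symmetric scale: map the largest absolute value in each group onto
            # half of the signed integer range, e.g. for n_bit=4 the divisor is
            # (7 - (-8)) / 2 = 7.5, so a group whose max |w| is 1.5 gets a scale
            # of 0.2.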
            scales = torch.max(
                scales, torch.full_like(scales, torch.finfo(torch.float32).eps)
            )
            # TODO: make sure abs(scales) is not too small?
            zeros = torch.full_like(scales, 0)
            return scales.to(precision).reshape(w.shape[0], -1), zeros.to(
                precision
            ).reshape(w.shape[0], -1)

        # Note: not using torchao.quantization.quant_primitives because it will
        # run into op registration issues.
        def group_quantize_tensor_symmetric(
            self, w, n_bit=4, group_size=128, precision=torch.float32
        ):
            scales, zeros = self.get_group_qparams_symmetric(
                w, n_bit, group_size, precision
            )
            max_int = 2 ** (n_bit - 1) - 1
            min_int = -(2 ** (n_bit - 1))
            # TODO: currently we don't know how to express torch.int4, we'll
            # add torch.int4 to core later
            w_int8 = torch.ops.quantized_decomposed.quantize_per_channel_group(
                w, scales, zeros, min_int, max_int, torch.int8, group_size
            )

            return w_int8, scales, zeros

        def fwd_input_per_token(self, input: torch.Tensor) -> torch.Tensor:
            ip_quant_min = -128
            ip_quant_max = 127
            (
                ip_scales,
                ip_zero_points,
            ) = torch.ops.quantized_decomposed.choose_qparams_per_token_asymmetric(
                input, torch.int8
            )

            input = torch.ops.quantized_decomposed.quantize_per_token(
                input,
                ip_scales,
                ip_zero_points,
                ip_quant_min,
                ip_quant_max,
                torch.int8,
            )
            input = torch.ops.quantized_decomposed.dequantize_per_token(
                input,
                ip_scales,
                ip_zero_points,
                ip_quant_min,
                ip_quant_max,
                torch.int8,
                self.op_dtype,
            )
            return input

        def quant_weight_per_channel(self):
            (
                self.w_scales.data,
                self.w_zero_points.data,
            ) = self.get_channel_qparams_symmetric(
                self.w, n_bit=self.w_n_bit, precision=self.op_dtype
            )
            self.w_q.data = torch.ops.quantized_decomposed.quantize_per_channel(
                self.w,
                self.w_scales,
                self.w_zero_points,
                axis=0,
                quant_min=self.w_quant_min,
                quant_max=self.w_quant_max,
                dtype=torch.int8,
            )

        def quant_weight_per_channel_group(self):
            self.w_q.data, w, zp = self.group_quantize_tensor_symmetric(
                self.w,
                n_bit=self.w_n_bit,
                group_size=self.group_size,
            )
            expected_min, expected_max = self.get_min_max(self.w_n_bit)
            assert (
                torch.min(self.w_q.data) >= expected_min
            ), "Found smaller than min element in quantized weight tensor"
            assert (
                torch.max(self.w_q.data) <= expected_max
            ), "Found larger than max element in quantized weight tensor"
            assert (
                w.ndim == 2 and zp.ndim == 2
            ), f"Expecting 2d scales and zp tensors, but got {w.shape}, {zp.shape}"
            self.w_scales.data, self.w_zero_points.data = w, zp

        def fwd_weight_per_channel(self) -> torch.Tensor:
            # This is HACKY because the dequant will produce fp32
            return torch.ops.quantized_decomposed.dequantize_per_channel(
                self.w_q,
                self.w_scales,
                self.w_zero_points,
                axis=0,
                quant_min=self.w_quant_min,
                quant_max=self.w_quant_max,
                dtype=torch.int8,  # Regardless of w_n_bit, convert to 4b later
            )

        def fwd_weight_per_channel_group(self) -> torch.Tensor:
            return torch.ops.quantized_decomposed.dequantize_per_channel_group(
                self.w_q,
                self.w_scales,
                self.w_zero_points,
                self.w_quant_min,
                self.w_quant_max,
                dtype=torch.int8,  # Regardless of w_n_bit, convert to 4b later
                group_size=self.group_size,
                output_dtype=self.op_dtype,
            )
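
        # The forward pass mirrors the decomposed dynamic-quantization pattern:
        # quantize/dequantize the activation per token, dequantize the stored
        # int8 weight (per channel or per channel group, depending on how the
        # scales were built), then run a regular float linear.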
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            # Input
            input = self.fwd_input_per_token(input)

            # Weights
            w = (
                self.fwd_weight_per_channel_group()
                if self.w_scales.ndim == 2
                else self.fwd_weight_per_channel()
            )
            assert isinstance(w, torch.Tensor)
            return torch.nn.functional.linear(input, w, self.bias)

    def _test_manual_dq_linear(
        self,
        mod: torch.nn.Module,
        inputs: Tuple[torch.Tensor],
        weight_groupwise: bool = False,
        use_bias: bool = False,
        atol: float = 1e-03,
        rtol: float = 1e-03,
    ):
        self.lower_and_test_without_partitioner(
            mod,
            inputs,
            func_name=inspect.stack()[0].function[5:],
            atol=atol,
            rtol=rtol,
        )

    def _run_manual_dqlinear_tests(self, weight_n_bit: int, op_dtype: torch.dtype):
        in_sizes = [1, 4, 4]
        input_sizes = [4, 37, 17]
        output_sizes = [4, 17, 37]

        for use_bias in [True, False]:
            for i, _ in enumerate(in_sizes):
                in_size = int(in_sizes[i])
                input_size = int(input_sizes[i])
                output_size = int(output_sizes[i])
                mod = self.ManualDQLinear(
                    input_channels=input_size,
                    output_channels=output_size,
                    weight_n_bit=weight_n_bit,
                    dtype=op_dtype,
                    use_bias=use_bias,
                )

                inputs = (torch.randn(1, in_size, input_size).to(op_dtype),)
                self._test_manual_dq_linear(mod, inputs, use_bias=use_bias)