# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from functools import partial
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer

from executorch.extension.llm.export.builder import DType

from sentencepiece import SentencePieceProcessor

try:
    from fairseq2.nn.embedding import (
        Embedding as fsEmbedding,
        StandardEmbedding as fsStandardEmbedding,
    )

    from fairseq2.nn.projection import Linear as fsLinear

    print("Using fairseq2 modules.")
except ImportError:
    fsEmbedding = nn.Embedding
    fsStandardEmbedding = nn.Embedding
    fsLinear = nn.Linear


def quantize(  # noqa C901
    model: torch.nn.Module,
    qmode: str,
    activation_dtype: Optional[DType],
    checkpoint_path: Optional[Path] = None,
    # The following argument is only available for int4 or GPTQ quantization.
    group_size: Optional[int] = 128,
    # The following arguments are only used for GPTQ.
    calibration_tasks: Optional[list] = None,
    calibration_limit: Optional[int] = None,
    calibration_seq_length: Optional[int] = None,
    pad_calibration_inputs: bool = False,
    percdamp: float = 0.01,
    blocksize: int = 128,
    tokenizer_path: Optional[Path] = None,
    verbose: bool = False,
) -> torch.nn.Module:
    """
    Quantizes a model's weights according to the requested quantization mode.

    Args:
        model: A model to quantize.
        qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq, vulkan_4w,
            or torchao:8da<bitwidth>w.

    Returns:
        A quantized model.
    """
    if activation_dtype is not None:
        torch_dtype = activation_dtype.to_torch_dtype()
    else:
        torch_dtype = torch.float16

    assert checkpoint_path, "Need to specify a checkpoint"
    # if checkpoint_path is None:
    #     checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")

    if qmode == "int8":
        # Add quantization mode options here: group size, bit width, etc.
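        # Weight-only int8: activations stay in floating point, while nn.Linear
        # weights are replaced with per-channel int8 tensors plus scales via
        # WeightOnlyInt8QuantHandler (defined below).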
        return WeightOnlyInt8QuantHandler(model).quantized_model()
    elif qmode.startswith("torchao:"):
        pattern = r"torchao:8da(\d+)w"
        matches = re.findall(pattern, qmode)
        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
        bitwidth = int(matches[0])
        _load_torchao_ops_aten()
        from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer

        with torch.no_grad():
            model = Int8DynActIntxWeightLinearQuantizer(
                device="cpu",
                precision=torch.float32,
                groupsize=group_size,
                bitwidth=bitwidth,
                has_weight_zeros=False,
            ).quantize(model)

        if verbose:
            print("quantized model:", model)
        return model
    elif qmode == "8da4w":
        # Check for required args
        if group_size is None:
            raise Exception("For 8da4w quantization, group size must be specified.")
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        model = Int8DynActInt4WeightQuantizer(
            precision=torch_dtype, groupsize=group_size
        ).quantize(model)

        if verbose:
            print("quantized model:", model)
        return model
    elif qmode == "8da4w-gptq":
        # Check for required args
        required_args: Optional[Any] = [
            group_size,
            calibration_limit,
            calibration_seq_length,
        ]
        if any(arg is None for arg in required_args):
            raise Exception(
                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
            )
        if calibration_tasks is None:
            calibration_tasks = ["wikitext"]

        try:
            # torchao 0.3+
            from torchao._eval import InputRecorder  # pyre-fixme[21]
        except ImportError:
            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore

        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

        if tokenizer_path is None:
            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
            model_file=str(tokenizer_path)
        )

        inputs = (
            InputRecorder(  # pyre-fixme[16]
                tokenizer,
                calibration_seq_length,
                None,  # input_prep_func
                pad_calibration_inputs,
                model.vocab_size,
            )
            .record_inputs(
                calibration_tasks,
                calibration_limit,
            )
            .get_inputs()
        )

        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
            blocksize,
            percdamp,
            group_size,
        )
        model = gptq_quantizer.quantize(model, inputs)
        return model
    elif qmode == "vulkan_4w":
        q_group_size = 256 if group_size is None else group_size
        model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)

        # Apply an additional quantizer for linear layers that aren't lowered to
        # Vulkan at the moment.
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        model = Int8DynActInt4WeightQuantizer(
            precision=torch_dtype, groupsize=q_group_size
        ).quantize(model)

        return model
    else:
        raise Exception(f"Unrecognized quantize mode: {qmode}")
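

# Example (illustrative sketch): quantize() is normally reached through
# get_quant_weight_transform() at the bottom of this file, but a direct call for
# 8da4w quantization might look like the following; the checkpoint path is a
# placeholder.
#
#   model = quantize(
#       model,
#       qmode="8da4w",
#       activation_dtype=None,  # falls back to torch.float16 internally
#       checkpoint_path=Path("/path/to/model.pth"),
#       group_size=128,
#   )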


def dynamically_quantize_per_channel(
    x,
    quant_min,
    quant_max,
    target_dtype,
    group_size: Optional[int] = None,
    *,
    scales_dtype=torch.float16,
    enable_non_multiple_groups=True,
):
    """
    Dynamically quantize per channel. This function is used for quantizing the
    weights of linear and embedding layers.

    Arguments:
        x: input tensor,
        quant_min: minimum value after quantization,
        quant_max: maximum value after quantization,
        target_dtype: target data type for weights after quantization,
        group_size: number of elements of the channel to quantize together

    Keyword arguments:
        scales_dtype: data type of scale,
        enable_non_multiple_groups: if True, allow the row size to not be a multiple
            of group size, with a final group smaller than group size.

    Assumptions:
        This function assumes symmetric quantization, axis == 0 and a dense memory
        format.
    """

    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    x_shape_1 = x.shape[1]

    if group_size is None or group_size == 0:
        items = x_shape_1
    elif ((x_shape_1 % group_size) == 0) or not enable_non_multiple_groups:
        assert group_size > 0, "group size must be positive"
        assert (
            x_shape_1 % group_size
        ) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {group_size}"
        items = group_size
    else:
        assert group_size > 0, "group size must be positive"
        print(
            f"row-size of weight matrix {x_shape_1} is not divisible by group size {group_size}, using nearest neighbor rounding"
        )
        assert (
            x_shape_1 % group_size != 0
        ), f"expected x.shape[1] to not be a multiple of group size {group_size}, but got {x_shape_1}"
        padding = group_size - (x_shape_1 % group_size)
        x = F.pad(x, (0, padding))
        items = group_size

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    x = x.view(x.shape[0], x.shape[1] // items, items)
    # get min and max
    min_val, max_val = torch.aminmax(x, dim=2)
    # print(f"min_val {min_val}")
    # print(f"max_val {max_val}")

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = (
        torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)
    )

    scales = scales.to(dtype=scales_dtype)
    quant = quant[:, :x_shape_1]

    return quant, scales, zero_points


#########################################################################
### QuantHandler API definition ###


class QuantHandler:
    def __init__(self, mod):
        self.mod = mod

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        pass

    def convert_for_runtime(self) -> nn.Module:
        pass

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod
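

# The three-step contract above (create_quantized_state_dict, convert_for_runtime,
# then load_state_dict in quantized_model) is what the concrete handlers below
# implement. A minimal sketch of a hypothetical no-op subclass, shown only to
# illustrate the expected shape of an implementation:
#
#   class PassthroughQuantHandler(QuantHandler):
#       def create_quantized_state_dict(self) -> Dict:
#           return self.mod.state_dict()
#
#       def convert_for_runtime(self) -> nn.Module:
#           return self.mod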


#########################################################################
### Weight-only int8 per-channel quantized code ###


def replace_linear_weight_only_int8_per_channel(module, node_type):
    for name, child in module.named_children():
        # print(f"name: {name}")
        if isinstance(child, nn.Linear):
            if (
                (node_type == "*")
                or (node_type == "output" and name == "output")
                or (node_type == "!output" and name != "output")
            ):
                # print(f"{name, child}")
                # print(f"in_features: {child.in_features}")
                # print(f"out_features: {child.out_features}")
                setattr(
                    module,
                    name,
                    WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
                )
        else:
            replace_linear_weight_only_int8_per_channel(child, node_type)


class WeightOnlyInt8QuantHandler(QuantHandler):
    def __init__(
        self,
        mod,
        device="cpu",
        *,
        node_type: str = "*",
        bitwidth: Optional[int] = None,
        group_size: Optional[int] = None,
    ):
        self.mod = mod
        self.group_size = group_size
        self.node_type = node_type
        if bitwidth is None:
            self.bitwidth = 8
        else:
            self.bitwidth = bitwidth

    @torch.no_grad()
    def create_quantized_state_dict(self) -> Dict:
        cur_state_dict = self.mod.state_dict()

        if self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, mod in self.mod.named_modules():
            # print(f"maybe? quantize {fqn}...{type(mod)}")
            if isinstance(mod, torch.nn.Linear) or isinstance(mod, fsLinear):
                # print(f"candidate {fqn}, nodetype {self.node_type}")
                if (
                    (self.node_type == "*")
                    or (self.node_type == "output" and fqn in ["output", "final_proj"])
                    or (
                        self.node_type == "!output"
                        and fqn not in ["output", "final_proj"]
                    )
                ):
                    print(
                        f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                    )

                    # print(f"initial weight shape {mod.weight.shape}")
                    input_weight = mod.weight.float()

                    # print(f"expanded weight shape {input_weight.shape}")
                    weight, scales, _ = dynamically_quantize_per_channel(
                        input_weight,
                        range_min,
                        range_max,
                        torch.int8,
                        self.group_size,
                        scales_dtype=mod.weight.dtype,
                    )

                    cur_state_dict[f"{fqn}.weight"] = weight
                    # squeeze makes group_size=rowsize unidimensional
                    cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

    def convert_for_runtime(self) -> nn.Module:
        replace_linear_weight_only_int8_per_channel(self.mod, self.node_type)
        return self.mod

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod


class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        device,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype=None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight", torch.zeros((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales
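

# Note on the forward pass above: `scales` holds one value per output channel, so
# scaling the matmul output is equivalent to dequantizing the weight first.
# Illustrative identity (not executed here):
#
#   F.linear(x, w_int8.to(x.dtype)) * scales
#       == F.linear(x, w_int8.to(x.dtype) * scales.unsqueeze(-1))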


def linear_forward_8da8w(
    x,
    weight_int8,
    scales,
    zeros,
    out_features,
    precision,
):
    from torchao.quantization.utils import per_token_dynamic_quant

    x = per_token_dynamic_quant(x)
    n_bit = 8
    quant_min = -(2 ** (n_bit - 1))
    quant_max = 2 ** (n_bit - 1) - 1
    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
        weight_int8,
        scales,
        zeros,
        0,
        quant_min,
        quant_max,
        torch.int8,
        out_dtype=precision,
    )
    c = torch.nn.functional.linear(x, w_dq)

    return c


class Int8DynActInt8WeightLinear(torch.nn.Module):
    """
    This module implements a dynamic quantized linear layer with int8 weight.
    Weights are per channel quantized. Parameters of importance:
        precision: precision of input and output. e.g. torch.float32 means input
            activation is float32 and output is float32.
    """

    __constants__ = ["in_features", "out_features"]

    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        precision: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.precision = precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        # currently storing unpacked int8 weights
        self.register_buffer(
            "weight",
            torch.zeros((out_features, in_features), dtype=torch.int8),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (out_features),
                dtype=torch.float32,
            ),
        )
        self.register_buffer(
            "zeros",
            torch.zeros(
                (out_features),
                dtype=torch.float32,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(self.precision)
        return linear_forward_8da8w(
            input,
            self.weight,
            self.scales,
            self.zeros,
            self.out_features,
            self.precision,
        )


#########################################################################
##### embedding table quantization ######


def replace_embedding_weight_only_grouped_int8_per_channel(
    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
):
    for name, child in module.named_children():
        # print(f"name: {name}")
        if isinstance(child, nn.Embedding):
            # print(f"{name, child}")
            # print(f"weights size: {child.weight.size()}")
            setattr(
                module,
                name,
                QuantizedGroupEmbedding(
                    device=device,
                    vocab_size=child.weight.shape[0],
                    embedding_dim=child.weight.shape[1],
                    group_size=group_size,
                    dtype=child.weight.dtype,
                    packed=packed,
                    bitwidth=bitwidth,
                ),
            )
        else:
            replace_embedding_weight_only_grouped_int8_per_channel(
                child, device, bitwidth, group_size, packed
            )
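

# The handler below can optionally pack 2-bit or 4-bit codes into uint8 bytes.
# Worked example for the 4-bit path (values are illustrative): a quantized pair
# (-8, 7) is shifted by +8 to (0, 15) and stored as the single byte
# 0 * 16 + 15 = 15 (0x0F), so each packed row is half the original width.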


class EmbeddingQuantHandler(QuantHandler):
    def __init__(
        self,
        mod,
        device="cpu",
        *,
        bitwidth: int = 8,
        group_size: Optional[int] = None,
        packed=False,
    ):
        if isinstance(packed, str):
            packed = packed == "True"
        self.mod = mod
        self.device = device
        self.group_size = group_size
        self.bitwidth = bitwidth
        self.packed = packed
        if (bitwidth not in [2, 4]) and packed:
            raise RuntimeError("packing only works with bitwidth 2 or 4")

    @torch.no_grad()
    def create_quantized_state_dict(self, packed=False) -> Dict:
        cur_state_dict = self.mod.state_dict()

        if self.bitwidth == 2:
            range_min = -2
            range_max = 1
        elif self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, nn.Embedding):
                # print("****")
                # print(f"Embedding identified: {fqn, mod}")
                # print(f"weights size: {mod.weight.size()}")
                # print(f"quantize {fqn}...")

                print(
                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                )
                weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(),
                    range_min,
                    range_max,
                    torch.int8,
                    self.group_size,
                    scales_dtype=mod.weight.dtype,
                )

                if packed:
                    if self.bitwidth == 2:
                        if weight.shape[-1] % 4 != 0:
                            raise RuntimeError("automatic padding not implemented yet")
                        weight_range_shifted = weight.add(2).view(torch.uint8)
                        weight_view = weight_range_shifted.view(
                            weight.shape[0], weight.shape[1] // 4, 4
                        )
                        weight_0 = weight_view[:, :, 0]
                        weight_1 = weight_view[:, :, 1] << 2
                        weight_2 = weight_view[:, :, 2] << 4
                        weight_3 = weight_view[:, :, 3] << 6
                        weight_packed = weight_0 + weight_1 + weight_2 + weight_3
                        weight = weight_packed
                    elif self.bitwidth == 4:
                        if weight.shape[-1] % 2 != 0:
                            raise RuntimeError("automatic padding not implemented yet")
                        weight_range_shifted = weight.add(8).view(torch.uint8)
                        weight_view = weight_range_shifted.view(
                            weight.shape[0], weight.shape[1] // 2, 2
                        )
                        weight_even = weight_view[:, :, 0] * 16  # left shift 4
                        weight_odd = weight_view[:, :, 1]
                        weight_packed = weight_even + weight_odd
                        weight = weight_packed

                weight = weight.to(device=self.device)
                scales = scales.to(device=self.device)
                # Update state dict
                cur_state_dict[f"{fqn}.weight"] = weight
                # squeeze makes group_size=rowsize unidimensional
                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

    def convert_for_runtime(self) -> nn.Module:
        replace_embedding_weight_only_grouped_int8_per_channel(
            self.mod, self.device, self.bitwidth, self.group_size, self.packed
        )
        return self.mod

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod
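

# Example (illustrative): quantizing a model's nn.Embedding tables to grouped int8
# with a hypothetical group size of 32:
#
#   model = EmbeddingQuantHandler(model, bitwidth=8, group_size=32).quantized_model()
#
# get_quant_embedding_transform() below builds this kind of call from the parsed
# CLI arguments.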


class QuantizedGroupEmbedding(torch.nn.Module):
    def __init__(
        self,
        device,
        vocab_size: int,
        embedding_dim: int,
        group_size: Optional[int] = None,
        dtype=torch.half,
        packed=False,
        bitwidth: int = 8,
    ) -> None:
        super().__init__()
        if group_size is None or group_size == 0:
            group_size = embedding_dim
        self.group_size = group_size
        self.dtype = dtype
        self.packed = packed
        self.bitwidth = bitwidth
        if not packed:
            self.register_buffer(
                "weight",
                torch.zeros(
                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
                ),
            )
        else:  # packed
            if bitwidth == 2:
                self.register_buffer(
                    "weight",
                    torch.zeros(
                        (vocab_size, embedding_dim // 4),
                        dtype=torch.uint8,
                        device=device,
                    ),
                )
            elif bitwidth == 4:
                self.register_buffer(
                    "weight",
                    torch.zeros(
                        (vocab_size, embedding_dim // 2),
                        dtype=torch.uint8,
                        device=device,
                    ),
                )

        groups_per_row = (embedding_dim + group_size - 1) // group_size
        if groups_per_row > 1:
            self.register_buffer(
                "scales",
                torch.ones(
                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
                ),
            )
        else:
            self.register_buffer(
                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
            )

    @torch.no_grad()
    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        if not self.packed:  # 8bit
            return torch.ops.quantized_decomposed.embedding_byte.dtype(
                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
            )
        else:  # packed
            if self.bitwidth == 2:
                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
                )

            # Remaining case (always return to make pyre happy)
            assert self.bitwidth == 4
            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
            )


############################ Source Transform Start #######################


def get_quant_embedding_transform(args):
    if args.embedding_quantize.startswith("torchao:"):
        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
        group_size = int(group_size)
        bitwidth = int(bitwidth)
        _load_torchao_ops_aten()
        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

        def _torchao_embedding_quantizer(model):
            with torch.no_grad():
                model = IntxWeightEmbeddingQuantizer(
                    device="cpu",
                    precision=torch.float32,
                    bitwidth=bitwidth,
                    groupsize=group_size,
                ).quantize(model)
            return model

        return _torchao_embedding_quantizer

    bitwidth, group_size = args.embedding_quantize.split(",")
    if group_size in ("none", "None", "0"):
        group_size = None
    else:
        group_size = int(group_size)
    bitwidth = int(bitwidth)
    return lambda model: EmbeddingQuantHandler(
        model,
        bitwidth=bitwidth,
        group_size=group_size,
        packed=(bitwidth in [2, 4]),
    ).quantized_model()


def get_quant_weight_transform(args, dtype_override, verbose):
    # If these optional args are None, don't provide them to quantize()
    quant_args_str = [
        "group_size",
        "calibration_tasks",
        "calibration_limit",
        "calibration_seq_length",
    ]
    arg_dict = vars(args)
    quant_args = {
        param: val
        for param in quant_args_str
        if (val := arg_dict.get(param)) is not None
    }

    return partial(
        quantize,
        **quant_args,
        qmode=args.quantization_mode,
        activation_dtype=dtype_override,
        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
        tokenizer_path=(
            Path(path) if (path := args.tokenizer_path) is not None else None
        ),
    )
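

# Example (illustrative sketch): both transforms expect an argparse-style `args`
# object carrying the fields they read (embedding_quantize, quantization_mode,
# checkpoint, tokenizer_path, and the optional calibration fields); the values
# shown here are hypothetical.
#
#   args.embedding_quantize = "8,32"  # bitwidth 8, group size 32
#   args.quantization_mode = "8da4w"
#   model = get_quant_embedding_transform(args)(model)
#   model = get_quant_weight_transform(args, dtype_override=None, verbose=False)(model)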


def _load_torchao_ops_aten():
    import glob
    import os

    libs = glob.glob(
        os.path.abspath(
            os.path.join(
                os.environ.get("CMAKE_INSTALL_PREFIX", ""),
                "lib/libtorchao_ops_aten.*",
            )
        )
    )
    assert (
        len(libs) == 1
    ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
    logging.info(f"Loading custom ops library: {libs[0]}")
    torch.ops.load_library(libs[0])


############################ Source Transform End #######################