# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from functools import partial
from pathlib import Path
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer

from executorch.extension.llm.export.builder import DType

from sentencepiece import SentencePieceProcessor

try:
    from fairseq2.nn.embedding import (
        Embedding as fsEmbedding,
        StandardEmbedding as fsStandardEmbedding,
    )

    from fairseq2.nn.projection import Linear as fsLinear

    print("Using fairseq2 modules.")
except ImportError:
    fsEmbedding = nn.Embedding
    fsStandardEmbedding = nn.Embedding
    fsLinear = nn.Linear


def quantize(  # noqa C901
    model: torch.nn.Module,
    qmode: str,
    activation_dtype: Optional[DType],
    checkpoint_path: Optional[Path] = None,
    # following arguments are only available when setting int4 or gptq quantization.
    group_size: Optional[int] = 128,
    # following arguments are only used for GPTQ
    calibration_tasks: Optional[list] = None,
    calibration_limit: Optional[int] = None,
    calibration_seq_length: Optional[int] = None,
    pad_calibration_inputs: bool = False,
    percdamp: float = 0.01,
    blocksize: int = 128,
    tokenizer_path: Optional[Path] = None,
    verbose: bool = False,
) -> torch.nn.Module:
    """
    Quantizes a model's weights according to the requested quantization mode.
    Args:
        model: A model to quantize.
        qmode: quantization mode, e.g. int8, 8da4w, 8da4w-gptq
    Returns:
        A quantized model.
    """
    if activation_dtype is not None:
        torch_dtype = activation_dtype.to_torch_dtype()
    else:
        torch_dtype = torch.float16

    assert checkpoint_path, "Need to specify a checkpoint"
    # if checkpoint_path is None:
    #     checkpoint_path = Path("checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")

    if qmode == "int8":
        # Add quantization mode options here: group size, bit width, etc.
        return WeightOnlyInt8QuantHandler(model).quantized_model()
    elif qmode.startswith("torchao:"):
        pattern = r"torchao:8da(\d+)w"
        matches = re.findall(pattern, qmode)
        assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
        bitwidth = int(matches[0])
        _load_torchao_ops_aten()
        from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer

        with torch.no_grad():
            model = Int8DynActIntxWeightLinearQuantizer(
                device="cpu",
                precision=torch.float32,
                groupsize=group_size,
                bitwidth=bitwidth,
                has_weight_zeros=False,
            ).quantize(model)

        if verbose:
            print("quantized model:", model)
        return model
    elif qmode == "8da4w":
        # Check for required args
        if group_size is None:
            raise Exception("For 8da4w quantization, group size must be specified.")
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        model = Int8DynActInt4WeightQuantizer(
            precision=torch_dtype, groupsize=group_size
        ).quantize(model)

        if verbose:
            print("quantized model:", model)
        return model
    elif qmode == "8da4w-gptq":
        # Check for required args
        required_args: Optional[Any] = [
            group_size,
            calibration_limit,
            calibration_seq_length,
        ]
        if any(arg is None for arg in required_args):
            raise Exception(
                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
            )
        if calibration_tasks is None:
            calibration_tasks = ["wikitext"]

        try:
            # torchao 0.3+
            from torchao._eval import InputRecorder  # pyre-fixme[21]
        except ImportError:
            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore

        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer

        if tokenizer_path is None:
            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
            model_file=str(tokenizer_path)
        )

        inputs = (
            InputRecorder(  # pyre-fixme[16]
                tokenizer,
                calibration_seq_length,
                None,  # input_prep_func
                pad_calibration_inputs,
                model.vocab_size,
            )
            .record_inputs(
                calibration_tasks,
                calibration_limit,
            )
            .get_inputs()
        )

        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
            blocksize,
            percdamp,
            group_size,
        )
        model = gptq_quantizer.quantize(model, inputs)
        return model
    elif qmode == "vulkan_4w":
        q_group_size = 256 if group_size is None else group_size
        model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model)

        # Apply additional quantizer for linear layers that aren't lowered to Vulkan
        # at the moment
        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

        model = Int8DynActInt4WeightQuantizer(
            precision=torch_dtype, groupsize=q_group_size
        ).quantize(model)

        return model
    else:
        raise Exception(f"Unrecognized quantize mode: {qmode}")


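# Example usage of quantize() (illustrative sketch; assumes `model` is a float
# Llama torch.nn.Module and that a checkpoint path is available):
#
#     quantized = quantize(
#         model,
#         qmode="8da4w",
#         activation_dtype=None,  # compute dtype then defaults to torch.float16
#         checkpoint_path=Path("checkpoint.pth"),
#         group_size=128,
#     )
#
# For qmode="8da4w-gptq", calibration_limit and calibration_seq_length must also
# be provided, and the tokenizer model is looked up next to the checkpoint when
# tokenizer_path is not given.
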
def dynamically_quantize_per_channel(
    x,
    quant_min,
    quant_max,
    target_dtype,
    group_size: Optional[int] = None,
    *,
    scales_dtype=torch.float16,
    enable_non_multiple_groups=True,
):
    """
    Dynamically quantize per channel. This function is used for quantizing weights
    for linear and embedding layers.

    Arguments:
        x: input tensor,
        quant_min: minimum value after quantization,
        quant_max: maximum value after quantization,
        target_dtype: target data type for weights after quantization,
        group_size: number of elements of the channel to quantize together

    Keyword arguments:
        scales_dtype: data type of scale,
        enable_non_multiple_groups: if True, allow the row size to not be a multiple of group size,
                        with a final group of a size less than group size.

    Assumptions:
        This function assumes symmetric quantization, axis == 0 and a dense memory format.
    """

    # assumes symmetric quantization
    # assumes axis == 0
    # assumes dense memory format
    # TODO(future): relax ^ as needed

    x_shape_1 = x.shape[1]

    if group_size is None or group_size == 0:
        items = x_shape_1
    elif ((x_shape_1 % group_size) == 0) or not enable_non_multiple_groups:
        assert group_size > 0, "group size must be positive"
        assert (
            x_shape_1 % group_size
        ) == 0, f"weights dimension 1 = {x_shape_1} must be a multiple of group size {group_size}"
        items = group_size
    else:
        assert group_size > 0, "group size must be positive"
        print(
            f"row-size of weight matrix {x_shape_1} is not divisible by group size {group_size}, using nearest neighbor rounding"
        )
        assert (
            x_shape_1 % group_size != 0
        ), f"expected x.shape[1] to not be a multiple of group size {group_size}, but got {x_shape_1}"
        padding = group_size - (x_shape_1 % group_size)
        x = F.pad(x, (0, padding))
        items = group_size

    # default setup for affine quantization of activations
    eps = torch.finfo(torch.float32).eps

    x = x.view(x.shape[0], x.shape[1] // items, items)
    # get min and max
    min_val, max_val = torch.aminmax(x, dim=2)
    # print(f"min_val {min_val}")
    # print(f"max_val {max_val}")

    # calculate scales and zero_points based on min and max
    # reference: https://fburl.com/code/srbiybme
    min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
    max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
    device = min_val_neg.device

    # reference: https://fburl.com/code/4wll53rk
    max_val_pos = torch.max(-min_val_neg, max_val_pos)
    scales = max_val_pos / (float(quant_max - quant_min) / 2)
    # ensure scales is the same dtype as the original tensor
    scales = torch.clamp(scales, min=eps).to(x.dtype)
    zero_points = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

    # quantize based on qmin/qmax/scales/zp
    # reference: https://www.internalfb.com/code/fbsource/[8edc275012b1]/fbcode/caffe2/torch/ao/quantization/fx/_decomposed.py?lines=63
    x_div = x / scales.unsqueeze(-1)
    x_round = torch.round(x_div)
    x_zp = x_round + zero_points.unsqueeze(-1)
    quant = (
        torch.clamp(x_zp, quant_min, quant_max).to(target_dtype).view(x.shape[0], -1)
    )

    scales = scales.to(dtype=scales_dtype)
    quant = quant[:, :x_shape_1]

    return quant, scales, zero_points


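# Worked example of the symmetric scheme above (illustrative, approximate values):
# for a single row x = [-1.0, 0.5, 2.0] with quant_min=-128, quant_max=127 and no
# grouping, max_val_pos = max(|-1.0|, 2.0) = 2.0, so
#
#     scale ≈ 2.0 / ((127 - (-128)) / 2) = 2.0 / 127.5 ≈ 0.0157
#     quant ≈ round(x / scale) = [-64, 32, 128] -> clamped to [-64, 32, 127]
#
# zero_points are all zero because the quantization is symmetric around 0.
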
#########################################################################
###                QuantHandler API definition                        ###


class QuantHandler:
    def __init__(self, mod):
        self.mod = mod

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        pass

    def convert_for_runtime(self) -> nn.Module:
        pass

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod


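# QuantHandler subclasses follow the two-step flow driven by quantized_model():
# create_quantized_state_dict() computes the quantized tensors, and
# convert_for_runtime() swaps eager modules for their quantized counterparts so
# the new state dict can be loaded. A minimal custom handler (hypothetical
# names, sketch only) would look like:
#
#     class MyQuantHandler(QuantHandler):
#         def create_quantized_state_dict(self) -> Dict:
#             return self.mod.state_dict()  # produce quantized tensors here
#
#         def convert_for_runtime(self) -> nn.Module:
#             return self.mod  # replace modules with quantized equivalents here
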
#########################################################################
###             Weight-only int8 per-channel quantized code           ###


def replace_linear_weight_only_int8_per_channel(module, node_type):
    for name, child in module.named_children():
        # print(f"name: {name}")
        if isinstance(child, nn.Linear):
            if (
                (node_type == "*")
                or (node_type == "output" and name == "output")
                or (node_type == "!output" and name != "output")
            ):
                # print(f"{name, child}")
                # print(f"in_features: {child.in_features}")
                # print(f"out_features: {child.out_features}")
                setattr(
                    module,
                    name,
                    WeightOnlyInt8Linear("cpu", child.in_features, child.out_features),
                )
        else:
            replace_linear_weight_only_int8_per_channel(child, node_type)


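# The node_type argument selects which Linear layers are swapped:
#   "*"       -> every nn.Linear in the module tree
#   "output"  -> only the submodule named "output" (typically the LM head here)
#   "!output" -> every nn.Linear except the one named "output"
#
# For example (illustrative), to quantize everything but the LM head:
#
#     replace_linear_weight_only_int8_per_channel(model, "!output")
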
class WeightOnlyInt8QuantHandler(QuantHandler):
    def __init__(
        self,
        mod,
        device="cpu",
        *,
        node_type: str = "*",
        bitwidth: Optional[int] = None,
        group_size: Optional[int] = None,
    ):
        self.mod = mod
        self.group_size = group_size
        self.node_type = node_type
        if bitwidth is None:
            self.bitwidth = 8
        else:
            self.bitwidth = bitwidth

    @torch.no_grad()
    def create_quantized_state_dict(self) -> Dict:
        cur_state_dict = self.mod.state_dict()

        if self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, mod in self.mod.named_modules():
            # print(f"maybe? quantize {fqn}...{type(mod)}")
            if isinstance(mod, torch.nn.Linear) or isinstance(mod, fsLinear):
                # print(f"candidate {fqn}, nodetype {self.node_type}")
                if (
                    (self.node_type == "*")
                    or (self.node_type == "output" and fqn in ["output", "final_proj"])
                    or (
                        self.node_type == "!output"
                        and fqn not in ["output", "final_proj"]
                    )
                ):
                    print(
                        f"quantize {self.node_type} {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                    )

                    # print(f"initial weight shape {mod.weight.shape}")
                    input_weight = mod.weight.float()

                    # print(f"expanded weight shape {input_weight.shape}")
                    weight, scales, _ = dynamically_quantize_per_channel(
                        input_weight,
                        range_min,
                        range_max,
                        torch.int8,
                        self.group_size,
                        scales_dtype=mod.weight.dtype,
                    )

                    cur_state_dict[f"{fqn}.weight"] = weight
                    # squeeze makes group_size=rowsize unidimensional
                    cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

    def convert_for_runtime(self) -> nn.Module:
        replace_linear_weight_only_int8_per_channel(self.mod, self.node_type)
        return self.mod

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod


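# Example usage (illustrative sketch; assumes `model` contains nn.Linear layers):
#
#     model_int8 = WeightOnlyInt8QuantHandler(model).quantized_model()
#
# With the defaults (node_type="*", group_size=None) every nn.Linear gets int8
# weights with one scale per output channel, matching the per-channel "scales"
# buffer registered by WeightOnlyInt8Linear below.
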
class WeightOnlyInt8Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        device,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dtype=None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.register_buffer(
            "weight", torch.zeros((out_features, in_features), dtype=torch.int8)
        )
        self.register_buffer("scales", torch.ones(out_features, dtype=torch.bfloat16))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight.to(dtype=input.dtype)) * self.scales


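# Note on the forward above: because the weight quantization is symmetric
# (zero point 0) with one scale per output row, computing
# F.linear(x, W_int8.to(x.dtype)) and then multiplying output channel j by
# scales[j] is mathematically equivalent to dequantizing first and running a
# regular float linear, i.e. F.linear(x, W_int8.float() * scales[:, None]).
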
def linear_forward_8da8w(
    x,
    weight_int8,
    scales,
    zeros,
    out_features,
    precision,
):
    from torchao.quantization.utils import per_token_dynamic_quant

    x = per_token_dynamic_quant(x)
    n_bit = 8
    quant_min = -(2 ** (n_bit - 1))
    quant_max = 2 ** (n_bit - 1) - 1
    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
        weight_int8,
        scales,
        zeros,
        0,
        quant_min,
        quant_max,
        torch.int8,
        out_dtype=precision,
    )
    c = torch.nn.functional.linear(x, w_dq)

    return c


class Int8DynActInt8WeightLinear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]

    in_features: int
    out_features: int
    weight: torch.Tensor

    """
    This module implements a dynamically quantized linear layer with an int8 weight.
    Weights are per-channel quantized. Parameters of importance:
    precision: precision of input and output, e.g. torch.float32 means the input
    activation is float32 and the output is float32.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=True,
        device=None,
        dtype=None,
        precision: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.precision = precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        # currently storing unpacked int8 weights
        self.register_buffer(
            "weight",
            torch.zeros((out_features, in_features), dtype=torch.int8),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (out_features),
                dtype=torch.float32,
            ),
        )
        self.register_buffer(
            "zeros",
            torch.zeros(
                (out_features),
                dtype=torch.float32,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = input.to(self.precision)
        return linear_forward_8da8w(
            input,
            self.weight,
            self.scales,
            self.zeros,
            self.out_features,
            self.precision,
        )


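# Example usage (illustrative sketch; assumes torchao is installed and the
# layer's buffers have been populated from a quantized state dict):
#
#     layer = Int8DynActInt8WeightLinear(4096, 4096, bias=False)
#     y = layer(torch.randn(1, 8, 4096))
#
# Activations are dynamically quantized per token inside linear_forward_8da8w,
# while the int8 weight is dequantized per channel before the float matmul.
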
#########################################################################
#####                   embedding table quantization               ######


def replace_embedding_weight_only_grouped_int8_per_channel(
    module, device, bitwidth: int = 8, group_size: Optional[int] = None, packed=False
):
    for name, child in module.named_children():
        # print(f"name: {name}")
        if isinstance(child, nn.Embedding):
            # print(f"{name, child}")
            # print(f"weights size: {child.weight.size()}")
            setattr(
                module,
                name,
                QuantizedGroupEmbedding(
                    device=device,
                    vocab_size=child.weight.shape[0],
                    embedding_dim=child.weight.shape[1],
                    group_size=group_size,
                    dtype=child.weight.dtype,
                    packed=packed,
                    bitwidth=bitwidth,
                ),
            )
        else:
            replace_embedding_weight_only_grouped_int8_per_channel(
                child, device, bitwidth, group_size, packed
            )


class EmbeddingQuantHandler(QuantHandler):
    def __init__(
        self,
        mod,
        device="cpu",
        *,
        bitwidth: int = 8,
        group_size: Optional[int] = None,
        packed=False,
    ):
        if isinstance(packed, str):
            packed = packed == "True"
        self.mod = mod
        self.device = device
        self.group_size = group_size
        self.bitwidth = bitwidth
        self.packed = packed
        if (bitwidth not in [2, 4]) and packed:
            raise RuntimeError("packed quantization only works with bitwidth 2 or 4")

    @torch.no_grad()
    def create_quantized_state_dict(self, packed=False) -> Dict:
        cur_state_dict = self.mod.state_dict()

        if self.bitwidth == 2:
            range_min = -2
            range_max = 1
        elif self.bitwidth == 4:
            range_min = -8
            range_max = 7
        elif self.bitwidth == 8:
            range_min = -128
            range_max = 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, nn.Embedding):
                # print("****")
                # print(f"Embedding identified: {fqn, mod}")
                # print(f"weights size: {mod.weight.size()}")
                # print(f"quantize {fqn}...")

                print(
                    f"quantize {fqn, mod} with group_size {self.group_size}, bitwidth {self.bitwidth}"
                )
                weight, scales, _ = dynamically_quantize_per_channel(
                    mod.weight.float(),
                    range_min,
                    range_max,
                    torch.int8,
                    self.group_size,
                    scales_dtype=mod.weight.dtype,
                )

                if packed:
                    if self.bitwidth == 2:
                        if weight.shape[-1] % 4 != 0:
                            raise RuntimeError("automatic padding not implemented yet")
                        weight_range_shifted = weight.add(2).view(torch.uint8)
                        weight_view = weight_range_shifted.view(
                            weight.shape[0], weight.shape[1] // 4, 4
                        )
                        weight_0 = weight_view[:, :, 0]
                        weight_1 = weight_view[:, :, 1] << 2
                        weight_2 = weight_view[:, :, 2] << 4
                        weight_3 = weight_view[:, :, 3] << 6
                        weight_packed = weight_0 + weight_1 + weight_2 + weight_3
                        weight = weight_packed
                    elif self.bitwidth == 4:
                        if weight.shape[-1] % 2 != 0:
                            raise RuntimeError("automatic padding not implemented yet")
                        weight_range_shifted = weight.add(8).view(torch.uint8)
                        weight_view = weight_range_shifted.view(
                            weight.shape[0], weight.shape[1] // 2, 2
                        )
                        weight_even = weight_view[:, :, 0] * 16  # left shift 4
                        weight_odd = weight_view[:, :, 1]
                        weight_packed = weight_even + weight_odd
                        weight = weight_packed

                weight = weight.to(device=self.device)
                scales = scales.to(device=self.device)
                # Update state dict
                cur_state_dict[f"{fqn}.weight"] = weight
                # squeeze makes group_size=rowsize unidimensional
                cur_state_dict[f"{fqn}.scales"] = scales.squeeze(dim=-1)

        return cur_state_dict

    def convert_for_runtime(self) -> nn.Module:
        replace_embedding_weight_only_grouped_int8_per_channel(
            self.mod, self.device, self.bitwidth, self.group_size, self.packed
        )
        return self.mod

    def quantized_model(self) -> nn.Module:
        model_updated_state_dict = self.create_quantized_state_dict(self.packed)
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod


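# Worked example of the 4-bit packing above (illustrative): two 4-bit values are
# stored per byte. A quantized pair [3, -5] is shifted by +8 to [11, 3]; the even
# element becomes the high nibble and the odd element the low nibble, so the
# packed byte is 11 * 16 + 3 = 179. The 2-bit path packs four values per byte,
# shifting by +2 and placing element i at bit offset 2 * i.
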
class QuantizedGroupEmbedding(torch.nn.Module):
    def __init__(
        self,
        device,
        vocab_size: int,
        embedding_dim: int,
        group_size: Optional[int] = None,
        dtype=torch.half,
        packed=False,
        bitwidth: int = 8,
    ) -> None:
        super().__init__()
        if group_size is None or group_size == 0:
            group_size = embedding_dim
        self.group_size = group_size
        self.dtype = dtype
        self.packed = packed
        self.bitwidth = bitwidth
        if not packed:
            self.register_buffer(
                "weight",
                torch.zeros(
                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
                ),
            )
        else:  # packed
            if bitwidth == 2:
                self.register_buffer(
                    "weight",
                    torch.zeros(
                        (vocab_size, embedding_dim // 4),
                        dtype=torch.uint8,
                        device=device,
                    ),
                )
            elif bitwidth == 4:
                self.register_buffer(
                    "weight",
                    torch.zeros(
                        (vocab_size, embedding_dim // 2),
                        dtype=torch.uint8,
                        device=device,
                    ),
                )

        groups_per_row = (embedding_dim + group_size - 1) // group_size
        if groups_per_row > 1:
            self.register_buffer(
                "scales",
                torch.ones(
                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
                ),
            )
        else:
            self.register_buffer(
                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
            )

    @torch.no_grad()
    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        if not self.packed:  # 8bit
            return torch.ops.quantized_decomposed.embedding_byte.dtype(
                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
            )
        else:  # packed
            if self.bitwidth == 2:
                return torch.ops.quantized_decomposed.embedding_2bit.dtype(
                    self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
                )

            # Remaining case (always return to make pyre happy)
            assert self.bitwidth == 4
            return torch.ops.quantized_decomposed.embedding_4bit.dtype(
                self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype
            )


############################ Source Transform Start #######################


def get_quant_embedding_transform(args):
    if args.embedding_quantize.startswith("torchao:"):
        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
        group_size = int(group_size)
        bitwidth = int(bitwidth)
        _load_torchao_ops_aten()
        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

        def _torchao_embedding_quantizer(model):
            with torch.no_grad():
                model = IntxWeightEmbeddingQuantizer(
                    device="cpu",
                    precision=torch.float32,
                    bitwidth=bitwidth,
                    groupsize=group_size,
                ).quantize(model)
            return model

        return _torchao_embedding_quantizer

    bitwidth, group_size = args.embedding_quantize.split(",")
    if group_size == "none" or group_size == "None" or group_size == "0":
        group_size = None
    else:
        group_size = int(group_size)
    bitwidth = int(bitwidth)
    return lambda model: EmbeddingQuantHandler(
        model,
        bitwidth=bitwidth,
        group_size=group_size,
        packed=(bitwidth in [2, 4]),
    ).quantized_model()


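# Example (illustrative): an args.embedding_quantize value of "4,32" selects
# bitwidth 4 with group size 32, producing a transform equivalent to
#
#     lambda model: EmbeddingQuantHandler(
#         model, bitwidth=4, group_size=32, packed=True
#     ).quantized_model()
#
# while a value of "torchao:4,32" routes to torchao's
# IntxWeightEmbeddingQuantizer instead.
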
def get_quant_weight_transform(args, dtype_override, verbose):
    # If these optional args are None, don't provide them to quantize()
    quant_args_str = [
        "group_size",
        "calibration_tasks",
        "calibration_limit",
        "calibration_seq_length",
    ]
    arg_dict = vars(args)
    quant_args = {
        param: val
        for param in quant_args_str
        if (val := arg_dict.get(param)) is not None
    }

    return partial(
        quantize,
        **quant_args,
        qmode=args.quantization_mode,
        activation_dtype=dtype_override,
        checkpoint_path=(Path(path) if (path := args.checkpoint) is not None else None),
        tokenizer_path=(
            Path(path) if (path := args.tokenizer_path) is not None else None
        ),
    )


def _load_torchao_ops_aten():
    import glob
    import os

    libs = glob.glob(
        os.path.abspath(
            os.path.join(
                os.environ.get("CMAKE_INSTALL_PREFIX", ""),
                "lib/libtorchao_ops_aten.*",
            )
        )
    )
    assert (
        len(libs) == 1
    ), f"Expected 1 library but got {len(libs)}.  If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
    logging.info(f"Loading custom ops library: {libs[0]}")
    torch.ops.load_library(libs[0])


############################ Source Transform End #######################