# /aosp_15_r20/external/executorch/backends/vulkan/_passes/int4_weight_only_quantizer.py
# (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# pyre-unsafe
import logging
from typing import Any, Callable, Dict, Optional, Type

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch
import torch.nn.functional as F

from torchao.quantization.GPTQ import _check_linear_int4_k
from torchao.quantization.unified import Quantizer
from torchao.quantization.utils import groupwise_affine_quantize_tensor


# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
# changes at the annotated lines.
class VkWeightOnlyInt4Linear(torch.nn.Module):
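    """Weight-only int4 replacement for ``torch.nn.Linear`` that stores its
    quantized weights as uint8 (two 4-bit values per element) and dispatches
    to the ``et_vk.linear_weight_int4`` custom operator.

    Minimal construction sketch (the dimensions are illustrative only and must
    satisfy the divisibility asserts in ``__init__``)::

        layer = VkWeightOnlyInt4Linear(
            in_features=256, out_features=128, groupsize=128, inner_k_tiles=8
        )
    """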
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=False,
        device=None,
        # TODO: remove the dtype field, it is not used
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        if self.padding:
            from torchao.utils import find_multiple

            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)
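            # For example, in_features=4000 would be padded up to 4096 here.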

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.device = device
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.precision = precision
        self.scales_precision = scales_precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (inner_k_tiles * 16) == 0"
        # In the original implementation, the weight buffer is registered with the
        # packed sizes, i.e. the result of calling the _convert_weight_to_int4pack
        # operator. However, the Vulkan implementation does not expect the weights
        # to be packed that way, so the weight tensor is registered with the
        # unpacked sizes instead. Note that in_features is divided by 2 because
        # each `uint8` tensor element contains 2 packed 4-bit values.
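        # For example, with in_features=256 each row of the weight buffer holds
        # 128 bytes, every byte carrying two adjacent 4-bit values (the nibble
        # order is an implementation detail of the packing routine).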
        self.register_buffer(
            "weight",
            torch.empty(
                (out_features, in_features // 2),
                dtype=torch.uint8,
                device=device,
            ),
        )
        self.dtype = dtype
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2),
                dtype=self.scales_precision,
                device=device,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.padding:
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        # The forward method is replaced. In the original implementation, the forward
        # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom
        # operator is called instead.
        return torch.ops.et_vk.linear_weight_int4(
            input,
            self.weight,
            self.groupsize,
            self.scales_and_zeros,
            self.inner_k_tiles,
        )


# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
# with small changes at the annotated locations.
def _vk_replace_linear_int4(
    module: torch.nn.Module,
    groupsize: int,
    inner_k_tiles: Optional[int],
    padding_allowed: bool,
    skip_layer_func: Optional[Callable] = None,
    precision: torch.dtype = torch.bfloat16,
    scales_precision: torch.dtype = torch.bfloat16,
    # Use the custom Vulkan linear layer as the default replacement class
    linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
    copy_weights: bool = False,
    # Serves the same purpose as `tensor_dim_limit` in
    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
    feature_limit: int = 16384,
):
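    """Recursively replace eligible ``torch.nn.Linear`` submodules of ``module``
    with instances of ``linear_class`` (``VkWeightOnlyInt4Linear`` by default).

    A linear layer is replaced only if it passes ``skip_layer_func``, satisfies
    the int4 shape constraints (or ``padding_allowed`` is set), and both of its
    feature dimensions are below ``feature_limit``.
    """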
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear) and (
            skip_layer_func is None or not skip_layer_func(child.weight)
        ):
            # Additional condition: the out/in features must be strictly less
            # than the `feature_limit` argument.
            if (
                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                or padding_allowed
            ) and (
                child.out_features < feature_limit and child.in_features < feature_limit
            ):
                new_linear = linear_class(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    device=child.weight.device,
                    groupsize=groupsize,
                    inner_k_tiles=inner_k_tiles,
                    precision=precision,
                    scales_precision=scales_precision,
                )
                if copy_weights and child.weight.device != torch.device("meta"):
                    # pyre-fixme[16]: `Module` has no attribute `weight`.
                    new_linear.weight = child.weight
                setattr(module, name, new_linear)
        else:
            _vk_replace_linear_int4(
                child,
                groupsize,
                inner_k_tiles,
                padding_allowed,
                skip_layer_func,
                precision,
                scales_precision,
                linear_class,
                copy_weights,
            )


# This module is copied from torchao.quantization.GPTQ.Int4WeightOnlyQuantizer
# with some changes at the annotated lines.
class VkInt4WeightOnlyQuantizer(Quantizer):
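    """Weight-only int4 quantizer targeting the Vulkan backend.

    Minimal usage sketch (the model being quantized is illustrative; every
    linear layer to be quantized must satisfy the shape checks below)::

        quantizer = VkInt4WeightOnlyQuantizer(groupsize=128)
        model = quantizer.quantize(model)
    """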
    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = True,
        inner_k_tiles: Optional[int] = 8,
        device: torch.device = torch.device("cpu"),  # noqa
        precision: torch.dtype = torch.float32,
        feature_limit: int = 16384,
    ) -> None:
        super().__init__()
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]

        self.inner_k_tiles = inner_k_tiles
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed
        self.device: torch.device = device
        self.precision: torch.dtype = precision
        # Serves the same purpose as `tensor_dim_limit` in
        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
        self.feature_limit = feature_limit

    @torch.no_grad()
    def _create_quantized_state_dict(
        self, model: torch.nn.Module
    ) -> Dict[str, torch.Tensor]:
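        """Return a copy of ``model``'s state dict in which the weight of every
        eligible ``nn.Linear`` is replaced by its int4-quantized uint8 form,
        plus the corresponding ``scales_and_zeros`` tensor for each layer.
        """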
        cur_state_dict = model.state_dict()
        for fqn, mod in model.named_modules():
            # Additional check to make sure the features do not exceed the
            # feature limit
            if isinstance(mod, torch.nn.Linear) and (
                mod.out_features < self.feature_limit
                and mod.in_features < self.feature_limit
            ):
                assert not mod.bias
                out_features = mod.out_features
                in_features = mod.in_features
                logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")

                assert (
                    in_features % self.groupsize == 0
                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding_allowed:
                        import torch.nn.functional as F

                        from torchao.utils import find_multiple

                        logging.warning(
                            f"{fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        padded_in_features = find_multiple(in_features, 1024)
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        logging.warning(
                            f"{fqn} is skipped; int4 requires that in_features is 32, 64, or divisible by 1024, "
                            + "and that groupsize and inner_k_tiles * 16 evenly divide it"
                        )
                        continue
                (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor(
                    weight,
                    4,  # n_bit
                    self.groupsize,
                    self.precision,  # dtype for scales_and_zeros
                )
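                # w_int4x8 packs two 4-bit values per uint8 element and
                # scales_and_zeros holds one (scale, zero point) pair per
                # quantization group, matching the buffer shapes registered in
                # VkWeightOnlyInt4Linear.__init__.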
                # In the original implementation, w_int4x8 is packed by calling the
                # _convert_weight_to_int4pack operator before storing the weight.
                # However, the Vulkan implementation does not expect that packed
                # layout, so the w_int4x8 tensor is stored as the weight directly.
                cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device)
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                    self.device
                )
        return cur_state_dict

    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
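        """Swap every eligible ``nn.Linear`` in ``model`` for a
        ``VkWeightOnlyInt4Linear`` so that the quantized state dict produced by
        ``_create_quantized_state_dict`` can be loaded into it.
        """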
        _vk_replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            self.padding_allowed,
            skip_layer_func=None,
            precision=self.precision,
            scales_precision=self.precision,
        )
        return model

    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
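        """Quantize ``model``: compute the int4 state dict, swap the eligible
        linear layers, then load the quantized tensors with ``strict=False``.
        """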
        state_dict = self._create_quantized_state_dict(model)
        model = self._convert_for_runtime(model)
        model.load_state_dict(state_dict, strict=False)
        return model
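

# The block below is a minimal smoke-test sketch, not part of the original
# module. It assumes torchao and the Vulkan custom op library imported above
# are available, and uses illustrative toy dimensions that satisfy the
# groupsize/inner_k_tiles divisibility checks. Executing the quantized layer's
# forward would additionally require the et_vk runtime, so only the module
# swap and state dict load are exercised here.
if __name__ == "__main__":

    class _ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(256, 128, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)

    toy = _ToyModel()
    quantizer = VkInt4WeightOnlyQuantizer(groupsize=128)
    toy = quantizer.quantize(toy)
    # The float linear layer should now be a VkWeightOnlyInt4Linear whose
    # uint8 weight buffer stores two 4-bit values per element.
    print(type(toy.linear).__name__, tuple(toy.linear.weight.shape))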