# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama.source_transformation.pre_quantization import (
    sanitize_checkpoint_from_pre_quantization,
    transform_embedding_for_pre_quantization,
    transform_linear_for_pre_quantization,
    transform_output_linear_for_pre_quantization,
)
from executorch.examples.models.llama.source_transformation.quantize import (
    dynamically_quantize_per_channel,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric


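# Each test below follows the same pattern: build a dummy model, quantize
# selected weights "offline" into its checkpoint, transform the model so it can
# consume the pre-quantized checkpoint, load that checkpoint, and verify that
# the loaded state_dict round-trips the quantized tensors exactly.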
class PreQuantizationTests(unittest.TestCase):

    def _prepare_dummy_model(self) -> Transformer:
        model_args = ModelArgs(
            max_seq_len=2048,
            max_batch_size=1,
            use_kv_cache=False,
            use_sdpa_with_kv_cache_op=False,
            generate_full_logits=False,
            enable_dynamic_shape=True,
            dim=768,
            multiple_of=32,
            n_heads=12,
            n_layers=12,
            norm_eps=1e-05,
            vocab_size=32000,
        )

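        # The Transformer is constructed with randomly initialized weights;
        # these tests only need a structurally valid state_dict, not trained
        # values.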
        model = Transformer(model_args)

        return model

    def test_transform_linear_for_pre_quantization(self):

        # Step 1: Create a Llama model with dummy weights.
        model = self._prepare_dummy_model()
        checkpoint = model.state_dict()

        # Step 2:
        # Do group-wise quantization and amend the checkpoint with
        # int8 weights and fp32 scales.
        group_size = 32
        n_bit = 4
        scales_precision = torch.float32
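        # group_quantize_tensor_symmetric (torchao) quantizes the weight
        # symmetrically to `n_bit` bits in groups of `group_size` values,
        # returning the quantized weight plus per-group scales and zero points;
        # only the weight and scales are written back into the checkpoint.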
        for fqn, mod in model.named_modules():
            if isinstance(mod, torch.nn.Linear):
                weight = mod.weight.data
                (
                    weight_int8,
                    scales,
                    zeros,
                ) = group_quantize_tensor_symmetric(
                    weight.to(torch.float32), n_bit, group_size, scales_precision
                )
                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
                checkpoint[f"{fqn}.scales"] = scales.to("cpu")

        # Step 3:
        # Transform the model so that it is compatible with the new checkpoint.
        transform_linear_for_pre_quantization(
            model,
            checkpoint,
            group_size,
            torch.float32,
        )
        sanitize_checkpoint_from_pre_quantization(checkpoint)

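        # strict=False: the transformed model's state_dict may contain buffers
        # (e.g. zeros) that the sanitized checkpoint does not provide.
        # assign=True: checkpoint tensors are assigned to the modules directly
        # instead of being copied, preserving their int8/fp32 dtypes.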
        model.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )

        new_checkpoint = model.state_dict()

        for k, v in checkpoint.items():
            # new_checkpoint contains extra tensors (e.g. zeros) that
            # checkpoint does not, so iterate over checkpoint's keys only.
            self.assertTrue(torch.allclose(new_checkpoint[k], v))

    def test_transform_output_linear_for_pre_quantization(self):
        # Step 1: Create a Llama model with dummy weights.
        model = self._prepare_dummy_model()
        checkpoint = model.state_dict()

        # Step 2:
        # Do per-channel quantization and amend the checkpoint with
        # int8 weights and fp32 scales.
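        # dynamically_quantize_per_channel computes one fp32 scale per output
        # channel and maps the weights of the final `output` projection to the
        # int8 range [-128, 127]; the per-channel zero points it also returns
        # are unused here.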
        for fqn, mod in model.named_modules():
            if isinstance(mod, torch.nn.Linear) and fqn == "output":
                weight = mod.weight.data
                weight_int8, scales, _ = dynamically_quantize_per_channel(
                    weight,
                    quant_min=-128,
                    quant_max=127,
                    target_dtype=torch.int8,
                    scales_dtype=torch.float32,
                )
                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
                checkpoint[f"{fqn}.scales"] = scales.to("cpu")

        # Step 3:
        # Transform the model so that it is compatible with the new checkpoint.
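        # Only the final `output` projection is swapped out here; the other
        # linear layers keep their original fp32 weights.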
        transform_output_linear_for_pre_quantization(
            model,
            checkpoint,
            torch.float32,
        )
        sanitize_checkpoint_from_pre_quantization(checkpoint)

        model.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )

        new_checkpoint = model.state_dict()

        for k, v in checkpoint.items():
            # new_checkpoint contains extra tensors (e.g. zeros) that
            # checkpoint does not, so iterate over checkpoint's keys only.
            self.assertTrue(torch.allclose(new_checkpoint[k], v))

    def test_transform_embedding_for_pre_quantization(self):

        # Step 1: Create a Llama model with dummy weights.
        model = self._prepare_dummy_model()
        checkpoint = model.state_dict()

        # Step 2:
        # Do group-wise quantization and amend the checkpoint with
        # int8 weights and fp32 scales.
        group_size = 32
        n_bit = 4
        scales_precision = torch.float32
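        # The token embedding table is quantized the same way as the linear
        # weights above: symmetric 4-bit values in groups of `group_size`,
        # with per-group fp32 scales stored alongside the quantized weights.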
        for fqn, mod in model.named_modules():
            # Quantize only the embedding layers (the token embedding table).
            if isinstance(mod, torch.nn.Embedding):
                weight = mod.weight.data
                (
                    weight_int8,
                    scales,
                    zeros,
                ) = group_quantize_tensor_symmetric(
                    weight.to(torch.float32), n_bit, group_size, scales_precision
                )
                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
                checkpoint[f"{fqn}.scales"] = scales.to("cpu")

        # Step 3:
        # Transform the model so that it is compatible with the new checkpoint.
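        # Unlike the linear transform above, this one also takes the bit width
        # and group size explicitly.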
        transform_embedding_for_pre_quantization(
            model,
            checkpoint,
            torch.float32,
            n_bit,
            group_size,
        )
        sanitize_checkpoint_from_pre_quantization(checkpoint)

        model.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )

        new_checkpoint = model.state_dict()

        for k, v in checkpoint.items():
            # new_checkpoint contains extra tensors (e.g. zeros) that
            # checkpoint does not, so iterate over checkpoint's keys only.
            self.assertTrue(torch.allclose(new_checkpoint[k], v))

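# Optional convenience for running this file directly; under a normal test
# runner this guard is unnecessary.
if __name__ == "__main__":
    unittest.main()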