# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.examples.models.llama.llama_transformer import ModelArgs, Transformer
from executorch.examples.models.llama.source_transformation.pre_quantization import (
    sanitize_checkpoint_from_pre_quantization,
    transform_embedding_for_pre_quantization,
    transform_linear_for_pre_quantization,
    transform_output_linear_for_pre_quantization,
)
from executorch.examples.models.llama.source_transformation.quantize import (
    dynamically_quantize_per_channel,
)
from torchao.quantization.utils import group_quantize_tensor_symmetric


class PreQuantizationTests(unittest.TestCase):

    def _prepare_dummy_model(self) -> Transformer:
        model_args = ModelArgs(
            max_seq_len=2048,
            max_batch_size=1,
            use_kv_cache=False,
            use_sdpa_with_kv_cache_op=False,
            generate_full_logits=False,
            enable_dynamic_shape=True,
            dim=768,
            multiple_of=32,
            n_heads=12,
            n_layers=12,
            norm_eps=1e-05,
            vocab_size=32000,
        )

        model = Transformer(model_args)

        return model

    def test_transform_linear_for_pre_quantization(self):
        # Step 1: Create a Llama model with dummy weights
        model = self._prepare_dummy_model()
        checkpoint = model.state_dict()

        # Step 2:
        # Do group-wise quantization and amend the checkpoint with
        # int8 weights and fp32 scales
        group_size = 32
        n_bit = 4
        scales_precision = torch.float32
        for fqn, mod in model.named_modules():
            if isinstance(mod, torch.nn.Linear):
                weight = mod.weight.data
                (
                    weight_int8,
                    scales,
                    zeros,
                ) = group_quantize_tensor_symmetric(
                    weight.to(torch.float32), n_bit, group_size, scales_precision
                )
                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
                checkpoint[f"{fqn}.scales"] = scales.to("cpu")

        # Step 3:
        # Transform the model so that it is compatible with the new checkpoint
        transform_linear_for_pre_quantization(
            model,
            checkpoint,
            group_size,
            torch.float32,
        )
        sanitize_checkpoint_from_pre_quantization(checkpoint)

        model.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )

        new_checkpoint = model.state_dict()

        for k, v in checkpoint.items():
            # new_checkpoint contains extra keys (e.g. "zeros"), so we iterate
            # over the checkpoint's keys instead of comparing the full dicts.
            self.assertTrue(torch.allclose(new_checkpoint[k], v))

    def test_transform_output_linear_for_pre_quantization(self):
        # Step 1: Create a Llama model with dummy weights
        model = self._prepare_dummy_model()
        checkpoint = model.state_dict()

        # Step 2:
        # Do per-channel quantization and amend the checkpoint with
        # int8 weights and fp32 scales
        for fqn, mod in model.named_modules():
            if isinstance(mod, torch.nn.Linear) and fqn == "output":
                weight = mod.weight.data
                weight_int8, scales, _ = dynamically_quantize_per_channel(
                    weight,
                    quant_min=-128,
                    quant_max=127,
                    target_dtype=torch.int8,
                    scales_dtype=torch.float32,
                )
                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
                checkpoint[f"{fqn}.scales"] = scales.to("cpu")

        # Step 3:
        # Transform the model so that it is compatible with the new checkpoint
        transform_output_linear_for_pre_quantization(
            model,
            checkpoint,
            torch.float32,
        )
        sanitize_checkpoint_from_pre_quantization(checkpoint)

        model.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )

        new_checkpoint = model.state_dict()

        for k, v in checkpoint.items():
            # new_checkpoint contains extra keys (e.g. "zeros"), so we iterate
            # over the checkpoint's keys instead of comparing the full dicts.
            self.assertTrue(torch.allclose(new_checkpoint[k], v))

    def test_transform_embedding_for_pre_quantization(self):
        # Step 1: Create a Llama model with dummy weights
        model = self._prepare_dummy_model()
        checkpoint = model.state_dict()

        # Step 2:
        # Do group-wise quantization and amend the checkpoint with
        # int8 weights and fp32 scales
        group_size = 32
        n_bit = 4
        scales_precision = torch.float32
        for fqn, mod in model.named_modules():
            # Quantize only the embedding layers
            if isinstance(mod, torch.nn.Embedding):
                weight = mod.weight.data
                (
                    weight_int8,
                    scales,
                    zeros,
                ) = group_quantize_tensor_symmetric(
                    weight.to(torch.float32), n_bit, group_size, scales_precision
                )
                checkpoint[f"{fqn}.weight"] = weight_int8.to("cpu")
                checkpoint[f"{fqn}.scales"] = scales.to("cpu")

        # Step 3:
        # Transform the model so that it is compatible with the new checkpoint
        transform_embedding_for_pre_quantization(
            model,
            checkpoint,
            torch.float32,
            n_bit,
            group_size,
        )
        sanitize_checkpoint_from_pre_quantization(checkpoint)

        model.load_state_dict(
            checkpoint,
            strict=False,
            assign=True,
        )

        new_checkpoint = model.state_dict()

        for k, v in checkpoint.items():
            # new_checkpoint contains extra keys (e.g. "zeros"), so we iterate
            # over the checkpoint's keys instead of comparing the full dicts.
            self.assertTrue(torch.allclose(new_checkpoint[k], v))
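

# Optional convenience (not required for pytest/unittest discovery): allow
# running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()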