# pyre-unsafe
import logging
from typing import Any, Callable, Dict, Optional, Type

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch
import torch.nn.functional as F

from torchao.quantization.GPTQ import _check_linear_int4_k
from torchao.quantization.unified import Quantizer
from torchao.quantization.utils import groupwise_affine_quantize_tensor


# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with
# changes at the annotated lines.
class VkWeightOnlyInt4Linear(torch.nn.Module):
    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=False,
        device=None,
        # TODO: remove the dtype field, it is not used
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        if self.padding:
            from torchao.utils import find_multiple

            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.device = device
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.precision = precision
        self.scales_precision = scales_precision

        if dtype is not None:
            raise ValueError("Please specify 'precision' instead of 'dtype'")

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"
        # In the original implementation, the weight buffer is registered with the
        # packed sizes, i.e. the result of calling the _convert_weight_to_int4pack
        # operator. However, the Vulkan implementation does not expect the weights
        # to be packed, therefore the weight tensor is registered with the unpacked
        # sizes instead. Note that in_features is divided by 2 because each `uint8`
        # tensor element contains two packed 4-bit values.
        self.register_buffer(
            "weight",
            torch.empty(
                (out_features, in_features // 2),
                dtype=torch.uint8,
                device=device,
            ),
        )
        self.dtype = dtype
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(
                (in_features // groupsize, out_features, 2),
                dtype=self.scales_precision,
                device=device,
            ),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.padding:
            input = F.pad(input, pad=(0, self.in_features - self.origin_in_features))
        # The forward method is replaced. In the original implementation, the forward
        # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan
        # custom operator is called instead.
        return torch.ops.et_vk.linear_weight_int4(
            input,
            self.weight,
            self.groupsize,
            self.scales_and_zeros,
            self.inner_k_tiles,
        )

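
# Illustrative sketch (not part of the original module): shows the unpacked
# int4 weight layout that VkWeightOnlyInt4Linear registers. The helper name
# and sizes below are hypothetical, chosen only to satisfy the shape
# constraints asserted in __init__ (out_features % 8 == 0 and
# in_features % (inner_k_tiles * 16) == 0).
def _example_vk_int4_linear_layout() -> None:
    layer = VkWeightOnlyInt4Linear(
        in_features=512,
        out_features=256,
        groupsize=128,
        inner_k_tiles=8,
    )
    # Each uint8 element packs two 4-bit values, so the weight buffer stores
    # in_features // 2 bytes per output row.
    assert layer.weight.shape == (256, 512 // 2)
    assert layer.weight.dtype == torch.uint8
    # One (scale, zero point) pair per group, per output feature.
    assert layer.scales_and_zeros.shape == (512 // 128, 256, 2)
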
# This function is copied from torchao.quantization.GPTQ._replace_linear_int4
# with small changes at the annotated locations.
def _vk_replace_linear_int4(
    module: torch.nn.Module,
    groupsize: int,
    inner_k_tiles: Optional[int],
    padding_allowed: bool,
    skip_layer_func: Optional[Callable] = None,
    precision: torch.dtype = torch.bfloat16,
    scales_precision: torch.dtype = torch.bfloat16,
    # Use the custom Vulkan linear layer as the default replacement class
    linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear,
    copy_weights: bool = False,
    # Serves the same purpose as `tensor_dim_limit` in
    # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
    feature_limit: int = 16384,
):
    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear) and (
            skip_layer_func is None or not skip_layer_func(child.weight)
        ):
            # Add an additional condition that the out/in features must not exceed
            # the `feature_limit` argument.
            if (
                _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
                or padding_allowed
            ) and (
                child.out_features < feature_limit and child.in_features < feature_limit
            ):
                new_linear = linear_class(
                    child.in_features,
                    child.out_features,
                    bias=False,
                    device=child.weight.device,
                    groupsize=groupsize,
                    inner_k_tiles=inner_k_tiles,
                    precision=precision,
                    scales_precision=scales_precision,
                )
                if copy_weights and child.weight.device != torch.device("meta"):
                    # pyre-fixme[16]: `Module` has no attribute `weight`.
                    new_linear.weight = child.weight
                setattr(module, name, new_linear)
        else:
            _vk_replace_linear_int4(
                child,
                groupsize,
                inner_k_tiles,
                padding_allowed,
                skip_layer_func,
                precision,
                scales_precision,
                linear_class,
                copy_weights,
                # Forward the feature limit so nested modules use the same bound.
                feature_limit=feature_limit,
            )

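
# Illustrative sketch (not part of the original module): demonstrates how
# _vk_replace_linear_int4 swaps eligible nn.Linear submodules in place. The
# toy Sequential model below is hypothetical; both linear layers satisfy the
# int4 shape constraints and fall under the default feature_limit, so both
# are replaced with VkWeightOnlyInt4Linear.
def _example_replace_linear_int4() -> None:
    model = torch.nn.Sequential(
        torch.nn.Linear(512, 256, bias=False),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 1024, bias=False),
    )
    _vk_replace_linear_int4(
        model,
        groupsize=128,
        inner_k_tiles=8,
        padding_allowed=True,
    )
    assert isinstance(model[0], VkWeightOnlyInt4Linear)
    assert isinstance(model[2], VkWeightOnlyInt4Linear)
    # Non-Linear submodules such as the ReLU are left untouched.
    assert isinstance(model[1], torch.nn.ReLU)
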
# This module is copied from torchao.quantization.GPTQ.Int4WeightOnlyQuantizer
# with some changes at the annotated lines.
class VkInt4WeightOnlyQuantizer(Quantizer):
    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = True,
        inner_k_tiles: Optional[int] = 8,
        device: torch.device = torch.device("cpu"),  # noqa
        precision: torch.dtype = torch.float32,
        feature_limit: int = 16384,
    ) -> None:
        super().__init__()
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]

        self.inner_k_tiles = inner_k_tiles
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed
        self.device: torch.device = device
        self.precision: torch.dtype = precision
        # Serves the same purpose as `tensor_dim_limit` in
        # executorch.backends.vulkan.partitioner.VulkanSupportedOperators
        self.feature_limit = feature_limit

    @torch.no_grad()
    def _create_quantized_state_dict(
        self, model: torch.nn.Module
    ) -> Dict[str, torch.Tensor]:
        cur_state_dict = model.state_dict()
        for fqn, mod in model.named_modules():
            # Add an additional check to make sure the features do not exceed the
            # feature limit.
            if isinstance(mod, torch.nn.Linear) and (
                mod.out_features < self.feature_limit
                and mod.in_features < self.feature_limit
            ):
                assert mod.bias is None, "require bias=False"
                out_features = mod.out_features
                in_features = mod.in_features
                logging.info(f"linear: {fqn}, in={in_features}, out={out_features}")

                assert (
                    in_features % self.groupsize == 0
                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

                weight = mod.weight.data
                if not _check_linear_int4_k(
                    in_features, self.groupsize, self.inner_k_tiles
                ):
                    if self.padding_allowed:
                        from torchao.utils import find_multiple

                        logging.warning(
                            f"warning: {fqn} is padded to satisfy in_features % 1024 == 0"
                        )
                        padded_in_features = find_multiple(in_features, 1024)
                        weight = F.pad(
                            weight, pad=(0, padded_in_features - in_features)
                        )
                    else:
                        logging.warning(
                            f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                            + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                        )
                        continue
                (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor(
                    weight,
                    4,  # n_bit
                    self.groupsize,
                    self.precision,  # dtype for scales_and_zeros
                )
                # In the original implementation, w_int4x8 is packed by calling the
                # _convert_weight_to_int4pack operator before storing the weight.
                # However, the Vulkan implementation does not expect the weights to
                # be packed, so the w_int4x8 tensor is stored as the weight instead.
                cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device)
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                    self.device
                )
        return cur_state_dict

    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
        _vk_replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            self.padding_allowed,
            skip_layer_func=None,
            precision=self.precision,
            scales_precision=self.precision,
            # Use the same feature limit as _create_quantized_state_dict so that
            # only layers with quantized weights in the state dict are replaced.
            feature_limit=self.feature_limit,
        )
        return model

    def quantize(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        state_dict = self._create_quantized_state_dict(model)
        model = self._convert_for_runtime(model)
        model.load_state_dict(state_dict, strict=False)
        return model
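

# Illustrative end-to-end sketch (not part of the original module): the toy
# model and sizes are hypothetical. quantize() builds a groupwise-quantized
# state dict, swaps eligible nn.Linear layers for VkWeightOnlyInt4Linear, and
# loads the quantized weights. This assumes a torchao/torch combination where
# groupwise_affine_quantize_tensor returns uint8 weights packed two 4-bit
# values per byte, matching the (out_features, in_features // 2) buffer above.
if __name__ == "__main__":
    example_model = torch.nn.Sequential(
        torch.nn.Linear(512, 256, bias=False),
        torch.nn.ReLU(),
        torch.nn.Linear(256, 128, bias=False),
    )
    quantizer = VkInt4WeightOnlyQuantizer(groupsize=128)
    quantized_model = quantizer.quantize(example_model)
    assert isinstance(quantized_model[0], VkWeightOnlyInt4Linear)
    assert quantized_model[0].weight.dtype == torch.uint8
    # Running the quantized model requires the et_vk.linear_weight_int4 custom
    # op registered by executorch.backends.vulkan.custom_ops_lib (imported at
    # the top of this file).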