# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import typing

import torch


def rotate_embeddings(model, R1: torch.Tensor) -> None:
    # Rotate the embeddings.
    for W in [model.tok_embeddings]:
        dtype = W.weight.data.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_attention_inputs(layer, R1) -> None:
    # Rotate the WQ, WK and WV matrices of the self-attention layer.
    for W in [layer.attention.wq, layer.attention.wk, layer.attention.wv]:
        dtype = W.weight.dtype
        W_ = W.weight.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_attention_output(layer, R1) -> None:
    # Rotate output matrix of the self-attention layer.
    W = layer.attention.wo
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)
    if W.bias is not None:
        b = W.bias.data.to(device="cpu", dtype=torch.float32)
        W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)


def rotate_mlp_input(layer, R1):
    # Rotate the MLP input weights.
    mlp_inputs = [layer.feed_forward.w3, layer.feed_forward.w1]
    for W in mlp_inputs:
        dtype = W.weight.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_mlp_output(layer, R1):
    # Rotate the MLP output weights and bias.
    W = layer.feed_forward.w2
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)

    if W.bias is not None:
        b = W.bias.data.to(device="cpu", dtype=torch.float32)
        W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)
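

# Illustrative sanity check (not part of the export flow; this helper name is an
# assumption of this sketch rather than SpinQuant API): for an orthogonal R1, rotating a
# block's input weights as W @ R1 and its output weights as R1.T @ W, exactly as the
# rotate_*_inputs / rotate_*_output helpers above do, leaves the block's function
# unchanged when its input activations arrive in the rotated basis x @ R1, and its
# output comes back in that same rotated basis.
def _check_r1_rotation_invariance(dim: int = 8, hidden: int = 16) -> bool:
    torch.manual_seed(0)
    x = torch.randn(3, dim)
    w_in = torch.randn(hidden, dim)  # plays the role of wq/wk/wv or w1/w3
    w_out = torch.randn(dim, hidden)  # plays the role of wo or w2

    # Reference block: linear -> SiLU -> linear (an nn.Linear computes x @ W.T).
    ref = torch.nn.functional.silu(x @ w_in.T) @ w_out.T

    # Random orthogonal R1; the real R1 comes from the optimized rotation checkpoint.
    R1, _ = torch.linalg.qr(torch.randn(dim, dim))

    # Same transforms as rotate_attention_inputs / rotate_attention_output above.
    w_in_rot = w_in @ R1
    w_out_rot = R1.T @ w_out

    # Rotated residual stream in, rotated residual stream out.
    out = torch.nn.functional.silu((x @ R1) @ w_in_rot.T) @ w_out_rot.T
    return torch.allclose(out, ref @ R1, atol=1e-4)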


def rotate_head(model, R1: torch.Tensor) -> None:
    # Rotate the head.
    W = model.output
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_ov_proj(layer, head_dim, R2=None):
    # Rotate the per-head value and output projections by R2.
    W = layer.attention.wv
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32).t()
    transposed_shape = W_.shape
    temp = W_.reshape(-1, transposed_shape[-1] // head_dim, head_dim)
    temp = temp.to(torch.float32) @ R2
    W_ = temp.reshape(transposed_shape).t()
    W.weight.data = W_.to(device="cpu", dtype=dtype)

    W = layer.attention.wo
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    init_shape = W_.shape
    temp = W_.reshape(-1, init_shape[-1] // head_dim, head_dim)
    temp = temp.to(torch.float32) @ R2
    W_ = temp.reshape(init_shape)
    W.weight.data = W_.to(device="cpu", dtype=dtype)


def cleanup_memory() -> None:
    """Run GC and clear GPU memory."""
    import gc

    # gc.collect and empty_cache are necessary to clean up GPU memory if the model was distributed.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_model_with_r1_r2(optimized_rotation_path: str):
    return lambda model: apply_spin_quant_r1_r2(model, optimized_rotation_path)


def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str):
    optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
    R1 = optimized_rotation["R1"].to(torch.float32)
    config = model.params
    # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
    #  `n_heads`.
    num_heads = config.n_heads
    head_dim = config.dim // num_heads

    rotate_embeddings(model, R1)
    rotate_head(model, R1)
    cleanup_memory()

    # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
    #  `Union[Tensor, Module]`.
    for idx, layer in enumerate(model.layers):
        key = f"model.layers.{idx}.self_attn.R2"
        R2 = optimized_rotation[key].to(torch.float32)
        rotate_attention_inputs(layer, R1)
        rotate_attention_output(layer, R1)
        rotate_mlp_input(layer, R1)
        rotate_mlp_output(layer, R1)
        rotate_ov_proj(layer, head_dim, R2=R2)
    return model


def fuse_ln_linear(
    layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]
) -> None:
    """
    Fuse the linear operations in Layernorm into the adjacent linear blocks.
    """
    for linear in linear_layers:
        linear_dtype = linear.weight.dtype

        # Calculate the new weight and bias.
        W_ = linear.weight.data.to(dtype=torch.float32)
        # pyre-fixme[58]: `*` is not supported for operand types `Tensor` and
        #  `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`.
        linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to(
            linear_dtype
        )

        if hasattr(layernorm, "bias"):
            if linear.bias is None:
                linear.bias = torch.nn.Parameter(
                    torch.zeros(linear.out_features, dtype=torch.float32)
                )
            linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul(
                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
                #  `Union[Tensor, Module]`.
                W_, layernorm.bias.to(dtype=torch.float32)
            )
            linear.bias.data = linear.bias.data.to(linear_dtype)
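

# Illustrative sanity check (not part of the export flow; this helper name is an
# assumption of this sketch rather than SpinQuant API): folding an RMSNorm scale into the
# following linear, as fuse_ln_linear does, leaves the composed function unchanged once
# the norm weight is reset to ones (which fuse_layer_norms below takes care of).
def _check_norm_fusion(dim: int = 8, out_features: int = 16) -> bool:
    torch.manual_seed(0)
    x = torch.randn(3, dim)
    gamma = torch.randn(dim)  # RMSNorm scale, i.e. layernorm.weight
    w = torch.randn(out_features, dim)  # the adjacent linear's weight

    def rms_normalize(t: torch.Tensor) -> torch.Tensor:
        return t * torch.rsqrt(t.pow(2).mean(dim=-1, keepdim=True) + 1e-6)

    # Reference: scaled RMSNorm followed by the linear.
    ref = (rms_normalize(x) * gamma) @ w.T
    # Fused: plain RMSNorm (weight == 1) with the scale absorbed column-wise into W.
    fused = rms_normalize(x) @ (w * gamma).T
    return torch.allclose(fused, ref, atol=1e-5)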


def fuse_layer_norms(model: torch.nn.Module):
    # Embedding fusion
    for W in [model.tok_embeddings]:
        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
        #  `weight`.
        W_ = W.weight.data.to(dtype=torch.float32)
        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
        #  `weight`.
        W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)

    # Fuse the linear operations in Layernorm into the adjacent linear blocks.
    # pyre-fixme[29]:
    #  `Union[BoundMethod[typing.Callable(torch._tensor.Tensor.__iter__)[[Named(self,
    #  torch._tensor.Tensor)], typing.Any], torch._tensor.Tensor],
    #  torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function.
    for layer in model.layers:
        # Fuse the input layernorms into the linear layers.
        fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1])
        fuse_ln_linear(
            layer.attention_norm,
            [
                layer.attention.wq,
                layer.attention.wk,
                layer.attention.wv,
            ],
        )

        W_norm = layer.ffn_norm.weight.data
        layer.ffn_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)
        W_norm = layer.attention_norm.weight.data
        layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

    fuse_ln_linear(
        # pyre-fixme[6]: For 1st argument expected `Module` but got `Union[Tensor,
        #  Module]`.
        model.norm,
        # pyre-fixme[6]: For 2nd argument expected `Iterable[Linear]` but got
        #  `Iterable[Union[Tensor, Module]]`.
        [model.output],
    )
    # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
    #  `weight`.
    W_norm = model.norm.weight.data
    model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

    return model
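

# Illustrative end-to-end sketch (this helper is not part of the module's public API and
# is not invoked by the export flow): the SpinQuant recipe folds the norm scales into the
# adjacent linears first and then applies the offline-optimized R1/R2 rotations.  It
# assumes a Llama-style model exposing the attributes used above (tok_embeddings, layers,
# attention/feed_forward projections, norm, output, params) and a rotation checkpoint
# containing "R1" plus per-layer "model.layers.{i}.self_attn.R2" entries; the surrounding
# export flow may wire these steps differently.
def _example_usage(model: torch.nn.Module, optimized_rotation_path: str) -> torch.nn.Module:
    model = fuse_layer_norms(model)
    transform = get_model_with_r1_r2(optimized_rotation_path)
    return transform(model)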