# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

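"""Utilities for SpinQuant-style model transformation: apply the optimized R1/R2
rotation matrices to a Llama-style model's weights and fuse layer-norm scales
into the adjacent linear layers."""
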
import typing

import torch


def rotate_embeddings(model, R1: torch.Tensor) -> None:
    # Rotate the embeddings.
    for W in [model.tok_embeddings]:
        dtype = W.weight.data.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_attention_inputs(layer, R1) -> None:
    # Rotate the WQ, WK and WV matrices of the self-attention layer.
    for W in [layer.attention.wq, layer.attention.wk, layer.attention.wv]:
        dtype = W.weight.data.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_attention_output(layer, R1) -> None:
    # Rotate output matrix of the self-attention layer.
    W = layer.attention.wo
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)
    if W.bias is not None:
        b = W.bias.data.to(device="cpu", dtype=torch.float32)
        W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)


def rotate_mlp_input(layer, R1) -> None:
    # Rotate the MLP input weights.
    mlp_inputs = [layer.feed_forward.w3, layer.feed_forward.w1]
    for W in mlp_inputs:
        dtype = W.weight.data.dtype
        W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
        W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_mlp_output(layer, R1) -> None:
    # Rotate the MLP output weights and bias.
    W = layer.feed_forward.w2
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype)

    if W.bias is not None:
        b = W.bias.data.to(device="cpu", dtype=torch.float32)
        W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype)


def rotate_head(model, R1: torch.Tensor) -> None:
    # Rotate the head.
    W = model.output
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype)


def rotate_ov_proj(layer, head_dim, R2=None) -> None:
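    # Apply the per-head rotation R2 to the output side of the value projection
    # (wv) and to the input side of the output projection (wo).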
    W = layer.attention.wv
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32).t()
    transposed_shape = W_.shape
    temp = W_.reshape(-1, transposed_shape[-1] // head_dim, head_dim)
    temp = temp.to(torch.float32) @ R2
    W_ = temp.reshape(transposed_shape).t()
    W.weight.data = W_.to(device="cpu", dtype=dtype)

    W = layer.attention.wo
    dtype = W.weight.data.dtype
    W_ = W.weight.data.to(device="cpu", dtype=torch.float32)
    init_shape = W_.shape
    temp = W_.reshape(-1, init_shape[-1] // head_dim, head_dim)
    temp = temp.to(torch.float32) @ R2
    W_ = temp.reshape(init_shape)
    W.weight.data = W_.to(device="cpu", dtype=dtype)


def cleanup_memory() -> None:
    """Run GC and clear GPU memory."""
    import gc

    # gc.collect and empty_cache are necessary to clean up GPU memory if the model was distributed
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def get_model_with_r1_r2(optimized_rotation_path: str):
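    # Returns a transform that loads the R1/R2 rotations from the given path and
    # applies them to a model (see apply_spin_quant_r1_r2).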
    return lambda model: apply_spin_quant_r1_r2(model, optimized_rotation_path)


def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str):
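    # Load the optimized R1 rotation and the per-layer R2 rotations from the
    # checkpoint and apply them in place to the embeddings, attention, MLP and
    # output head weights.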
    optimized_rotation = torch.load(optimized_rotation_path, weights_only=True)
    R1 = optimized_rotation["R1"].to(torch.float32)
    config = model.params
    # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
    #  `n_heads`.
    num_heads = config.n_heads
    head_dim = config.dim // num_heads

    rotate_embeddings(model, R1)
    rotate_head(model, R1)
    cleanup_memory()

    # pyre-fixme[6]: For 1st argument expected `Iterable[Variable[_T]]` but got
    #  `Union[Tensor, Module]`.
    for idx, layer in enumerate(model.layers):
        key = f"model.layers.{idx}.self_attn.R2"
        R2 = optimized_rotation[key].to(torch.float32)
        rotate_attention_inputs(layer, R1)
        rotate_attention_output(layer, R1)
        rotate_mlp_input(layer, R1)
        rotate_mlp_output(layer, R1)
        rotate_ov_proj(layer, head_dim, R2=R2)
    return model


def fuse_ln_linear(
    layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear]
) -> None:
    """
    Fuse the layernorm's scale (and bias, if present) into the adjacent linear blocks.
    """
    for linear in linear_layers:
        linear_dtype = linear.weight.dtype

        # Calculate the new weight and bias
        W_ = linear.weight.data.to(dtype=torch.float32)
        # pyre-fixme[58]: `*` is not supported for operand types `Tensor` and
        #  `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`.
        linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to(
            linear_dtype
        )

        if hasattr(layernorm, "bias") and layernorm.bias is not None:
            if linear.bias is None:
                linear.bias = torch.nn.Parameter(
                    torch.zeros(linear.out_features, dtype=torch.float32)
                )
            linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul(
                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
                #  `Union[Tensor, Module]`.
                W_, layernorm.bias.to(dtype=torch.float32)
            )
            linear.bias.data = linear.bias.data.to(linear_dtype)


def fuse_layer_norms(model: torch.nn.Module):
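    # Fold each norm's scale into the adjacent linear layers, re-center the
    # embedding weights, and reset the norm weights to ones.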
    # Embedding fusion
    for W in [model.tok_embeddings]:
        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
        #  `weight`.
        W_ = W.weight.data.to(dtype=torch.float32)
        # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
        #  `weight`.
        W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype)

    # Fuse the linear operations in Layernorm into the adjacent linear blocks.
    # pyre-fixme[29]:
    #  `Union[BoundMethod[typing.Callable(torch._tensor.Tensor.__iter__)[[Named(self,
    #  torch._tensor.Tensor)], typing.Any], torch._tensor.Tensor],
    #  torch._tensor.Tensor, torch.nn.modules.module.Module]` is not a function.
    for layer in model.layers:
        # Fuse the input layernorms into the linear layers
        fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1])
        fuse_ln_linear(
            layer.attention_norm,
            [
                layer.attention.wq,
                layer.attention.wk,
                layer.attention.wv,
            ],
        )

        W_norm = layer.ffn_norm.weight.data
        layer.ffn_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)
        W_norm = layer.attention_norm.weight.data
        layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

    fuse_ln_linear(
        # pyre-fixme[6]: For 1st argument expected `Module` but got `Union[Tensor,
        #  Module]`.
        model.norm,
        # pyre-fixme[6]: For 2nd argument expected `Iterable[Linear]` but got
        #  `Iterable[Union[Tensor, Module]]`.
        [model.output],
    )
    # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
    #  `weight`.
    W_norm = model.norm.weight.data
    model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32)

    return model
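

# Example usage (sketch; the model variable and checkpoint path below are
# placeholders): the norm scales are typically fused into the linear weights
# before the optimized rotations are applied, e.g.
#
#     model = fuse_layer_norms(model)
#     model = apply_spin_quant_r1_r2(model, "/path/to/optimized_rotation.pt")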