# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict

import numpy as np

import torch


def prune_output_vocab(
    model: torch.nn.Module,
    token_map: Dict[int, int],
    output_layer_name: str = "output",
) -> torch.nn.Module:
    """Prune the model's output linear layer, keeping only the tokens in the token map.

    Note: Pruning is performed in-place.

    Args:
        model: The model to prune.
        token_map: A dictionary mapping new token ids to the old token ids to preserve,
            e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}.
        output_layer_name: Name of the output layer to prune.

    Returns:
        The pruned model.
    """
    assert hasattr(
        model, output_layer_name
    ), f"Model does not have {output_layer_name} layer"
    output_layer = getattr(model, output_layer_name)
    assert isinstance(
        output_layer, torch.nn.Linear
    ), "Output layer is not a linear layer"
    original_shape = output_layer.weight.shape
    input_features = original_shape[1]
    num_pruned_tokens = len(token_map)
    has_bias = output_layer.bias is not None
    weight_dtype = output_layer.weight.dtype
    pruned_layer = torch.nn.Linear(input_features, num_pruned_tokens, bias=has_bias)
    pruned_layer.to(dtype=weight_dtype)
    # Stage the copies in float32 NumPy buffers; NumPy has no bfloat16 dtype.
    pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32)
    pruned_layer_bias = None
    if has_bias:
        pruned_layer_bias = np.zeros(pruned_layer.bias.shape, dtype=np.float32)
    for i, token_id in token_map.items():
        # Copy the weights and biases from the original layer to the pruned layer
        pruned_wt = output_layer.weight[token_id].detach()
        if weight_dtype == torch.bfloat16:
            pruned_wt = pruned_wt.float()
        pruned_layer_weights[i] = pruned_wt.numpy()
        if has_bias:
            pruned_bias = output_layer.bias[token_id].detach()
            if weight_dtype == torch.bfloat16:
                pruned_bias = pruned_bias.float()
            pruned_layer_bias[i] = pruned_bias.numpy()
    with torch.no_grad():
        pruned_layer.weight.copy_(
            torch.tensor(pruned_layer_weights, dtype=weight_dtype)
        )
        if has_bias:
            pruned_layer.bias.copy_(torch.tensor(pruned_layer_bias, dtype=weight_dtype))

    # Replace the original layer with the pruned layer
    setattr(model, output_layer_name, pruned_layer)

    return model

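# Illustrative sketch (not part of the original module): row i of the pruned
# output weight is row token_map[i] of the original, so the logit the pruned
# model emits for new token 0 below is what the unpruned model produced for
# old token 221. `model` stands in for any module with an `output` head.
#
#     old_weight = model.output.weight.detach().clone()
#     model = prune_output_vocab(model, token_map={0: 221, 1: 1325})
#     assert torch.equal(model.output.weight[0], old_weight[221])

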
def prune_input_vocab(
    model: torch.nn.Module,
    token_map: Dict[int, int],
    input_layer_name: str = "tok_embeddings",
) -> torch.nn.Module:
    """Prune the model's input embedding layer, keeping only the tokens in the token map.

    Note: Pruning is performed in-place.

    Args:
        model: The model to prune.
        token_map: A dictionary mapping new token ids to the old token ids to preserve,
            e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243}.
        input_layer_name: Name of the input embedding layer to prune.

    Returns:
        The pruned model.
    """
    assert hasattr(
        model, input_layer_name
    ), f"Model does not have {input_layer_name} layer"
    input_layer = getattr(model, input_layer_name)
    assert isinstance(
        input_layer, torch.nn.Embedding
    ), "Input layer is not an Embedding layer"
    original_shape = input_layer.weight.shape
    num_pruned_tokens = len(token_map)
    weight_dtype = input_layer.weight.dtype
    pruned_layer = torch.nn.Embedding(num_pruned_tokens, original_shape[1])
    pruned_layer.to(dtype=weight_dtype)
    # Stage the copy in a float32 NumPy buffer; NumPy has no bfloat16 dtype.
    pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32)
    for i, token_id in token_map.items():
        # Copy the weights from the original layer to the pruned layer
        pruned_wt = input_layer.weight[token_id].detach()
        if weight_dtype == torch.bfloat16:
            pruned_wt = pruned_wt.float()
        pruned_layer_weights[i] = pruned_wt.numpy()
    with torch.no_grad():
        pruned_layer.weight.copy_(
            torch.tensor(pruned_layer_weights, dtype=weight_dtype)
        )

    # Replace the original layer with the pruned layer
    setattr(model, input_layer_name, pruned_layer)

    return model
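
if __name__ == "__main__":
    # Illustrative smoke test (an assumption-laden sketch, not part of the
    # original module): `_ToyModel` is a hypothetical stand-in for a
    # Llama-style model with a `tok_embeddings` input layer and an `output`
    # head; the token ids are placeholders.
    class _ToyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.tok_embeddings = torch.nn.Embedding(100, 16)
            self.output = torch.nn.Linear(16, 100)

    toy = _ToyModel()
    token_map = {0: 7, 1: 42, 2: 99}  # new token id -> old token id
    toy = prune_input_vocab(toy, token_map)
    toy = prune_output_vocab(toy, token_map)
    assert toy.tok_embeddings.weight.shape == (3, 16)
    assert toy.output.weight.shape == (3, 16)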