1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7from typing import Dict 8 9import numpy as np 10 11import torch 12 13 14def prune_output_vocab( 15 model: torch.nn.Module, 16 token_map: Dict[int, int], 17 output_layer_name: str = "output", 18) -> torch.nn.Module: 19 """Prune the model output linear layer while keeping the tokens in the token map. 20 21 Note: Pruning is performed in-place. 22 23 Args: 24 model: The model to prune. 25 token_map: A dictionary mapping from new token ids to the old token ids to preserve. 26 e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243} 27 output_layer_name: name of the output layer to prune 28 29 Returns: 30 The pruned model. 31 """ 32 assert hasattr( 33 model, output_layer_name 34 ), f"Model does not have {output_layer_name} layer" 35 output_layer = getattr(model, output_layer_name) 36 assert isinstance( 37 output_layer, torch.nn.Linear 38 ), "Output layer is not a linear layer" 39 original_shape = output_layer.weight.shape 40 input_features = original_shape[1] 41 num_pruned_tokens = len(token_map) 42 has_bias = output_layer.bias is not None 43 weight_dtype = output_layer.weight.dtype 44 pruned_layer = torch.nn.Linear(input_features, num_pruned_tokens, bias=has_bias) 45 pruned_layer.to(dtype=weight_dtype) 46 pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32) 47 pruned_layer_bias = None 48 if has_bias: 49 pruned_layer_bias = np.zeros(pruned_layer.bias.shape, dtype=np.float32) 50 for i, token_id in token_map.items(): 51 # Copy the weights and biases from the original layer to the pruned layer 52 pruned_wt = output_layer.weight[token_id].detach() 53 if weight_dtype == torch.bfloat16: 54 pruned_wt = pruned_wt.float() 55 pruned_layer_weights[i] = pruned_wt.numpy() 56 if has_bias: 57 pruned_bias = output_layer.bias[token_id].detach() 58 if weight_dtype == torch.bfloat16: 59 pruned_bias = pruned_bias.float() 60 pruned_layer_bias[i] = pruned_bias.numpy() 61 with torch.no_grad(): 62 pruned_layer.weight.copy_( 63 torch.tensor(pruned_layer_weights, dtype=weight_dtype) 64 ) 65 if has_bias: 66 pruned_layer.bias.copy_(torch.tensor(pruned_layer_bias, dtype=weight_dtype)) 67 68 # Replace the original layer with the pruned layer 69 setattr(model, output_layer_name, pruned_layer) 70 71 return model 72 73 74def prune_input_vocab( 75 model: torch.nn.Module, 76 token_map: Dict[int, int], 77 imput_layer_name: str = "tok_embeddings", 78) -> torch.nn.Module: 79 """Prune the model input embedding layer while keeping the tokens in the token map. 80 81 Note: Pruning is performed in-place. 82 83 Args: 84 model: The model to prune. 85 token_map: A dictionary mapping from new token ids to the old token ids to preserve. 86 e.g. {0: 221, 1: 1325, 2: 1542, 3: 1728, 4: 18243} 87 imput_layer_name: name of the input embedding layer to prune 88 89 Returns: 90 The pruned model. 91 """ 92 assert hasattr( 93 model, imput_layer_name 94 ), f"Model does not have {imput_layer_name} layer" 95 input_layer = getattr(model, imput_layer_name) 96 assert isinstance( 97 input_layer, torch.nn.Embedding 98 ), "Input layer is not an Embedding layer" 99 original_shape = input_layer.weight.shape 100 num_pruned_tokens = len(token_map) 101 weight_dtype = input_layer.weight.dtype 102 pruned_layer = torch.nn.Embedding(num_pruned_tokens, original_shape[1]) 103 pruned_layer.to(dtype=weight_dtype) 104 pruned_layer_weights = np.zeros(pruned_layer.weight.shape, dtype=np.float32) 105 for i, token_id in token_map.items(): 106 # Copy the weights from the original layer to the pruned layer 107 pruned_wt = input_layer.weight[token_id].detach() 108 if weight_dtype == torch.bfloat16: 109 pruned_wt = pruned_wt.float() 110 pruned_layer_weights[i] = pruned_wt.numpy() 111 with torch.no_grad(): 112 pruned_layer.weight.copy_( 113 torch.tensor(pruned_layer_weights, dtype=weight_dtype) 114 ) 115 116 # Replace the original layer with the pruned layer 117 setattr(model, imput_layer_name, pruned_layer) 118 119 return model 120