# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict, Mapping


def convert_model_state_dict(
    state_dict: Dict[str, Any], key_map: Mapping[str, str]
) -> Dict[str, Any]:
    """Convert a model state dictionary by renaming its keys.

    :param state_dict:
        The original model state dictionary.
    :param key_map:
        A map of regex patterns to replacement key fragments.

    :returns:
        A state dictionary whose keys have been renamed according to
        ``key_map``.
    """
    new_state_dict = {}

    def get_new_key(old_key: str) -> str:
        # Apply the first pattern that changes the key; fall back to the
        # original key if no pattern matches.
        for old_pattern, replacement in key_map.items():
            if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                return new_key

        return old_key

    # Rename the module keys of the state dictionary.
    for old_key in state_dict.keys():
        new_key = get_new_key(old_key)

        new_state_dict[new_key] = state_dict[old_key]

    return new_state_dict


def convert_to_llama_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a fairseq2 LLaMA checkpoint to the reference format."""
    # state_dict = checkpoint["model"]

    # The patterns are anchored and their dots escaped so that they match
    # only literal key prefixes.
    key_map = {
        # fmt: off
        r"^decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":      r"layers.\1.attention.wq.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":      r"layers.\1.attention.wk.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":      r"layers.\1.attention.wv.",
        r"^decoder\.layers\.([0-9]+)\.self_attn\.output_proj\.": r"layers.\1.attention.wo.",
        r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":   r"layers.\1.attention_norm.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.gate_proj\.":         r"layers.\1.feed_forward.w1.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.output_proj\.":       r"layers.\1.feed_forward.w2.",
        r"^decoder\.layers\.([0-9]+)\.ffn\.inner_proj\.":        r"layers.\1.feed_forward.w3.",
        r"^decoder\.layers\.([0-9]+)\.ffn_layer_norm\.":         r"layers.\1.ffn_norm.",
        r"^decoder\.layer_norm\.":                               r"norm.",
        r"^decoder_frontend\.embed\.":                           r"tok_embeddings.",
        r"^final_proj\.":                                        r"output.",
        # fmt: on
    }

    return convert_model_state_dict(checkpoint, key_map)
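

# The following is a minimal, illustrative sketch of how the conversion is
# meant to be used. The toy state dictionary and its ``None`` values are
# hypothetical; a real fairseq2 checkpoint holds ``torch.Tensor`` weights.
if __name__ == "__main__":
    fairseq2_state_dict = {
        "decoder_frontend.embed.weight": None,
        "decoder.layers.0.self_attn.q_proj.weight": None,
        "decoder.layers.0.ffn.gate_proj.weight": None,
        "decoder.layer_norm.weight": None,
        "final_proj.weight": None,
    }

    reference_state_dict = convert_to_llama_checkpoint(fairseq2_state_dict)

    # Expected key renames:
    #   decoder_frontend.embed.weight            -> tok_embeddings.weight
    #   decoder.layers.0.self_attn.q_proj.weight -> layers.0.attention.wq.weight
    #   decoder.layers.0.ffn.gate_proj.weight    -> layers.0.feed_forward.w1.weight
    #   decoder.layer_norm.weight                -> norm.weight
    #   final_proj.weight                        -> output.weight
    for key in reference_state_dict:
        print(key)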