# external/executorch/examples/models/llama/fairseq2.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict, Mapping


def convert_model_state_dict(
    state_dict: Dict[str, Any], key_map: Mapping[str, str]
) -> Dict[str, Any]:
    """Convert a model state dictionary by remapping its keys.

    :param state_dict:
        The original model state dictionary.
    :param key_map:
        A map of regex patterns to replacement key prefixes.

    :returns:
        A model state dictionary whose keys have been rewritten according to
        ``key_map``; the values are carried over unchanged.
    """
    new_state_dict = {}

    def get_new_key(old_key: str) -> str:
        # Return the key rewritten by the first pattern that changes it;
        # keys matched by no pattern are returned unchanged.
        for old_pattern, replacement in key_map.items():
            if (new_key := re.sub(old_pattern, replacement, old_key)) != old_key:
                return new_key

        return old_key

    # Rewrite each key; the associated values are copied over as-is.
    for old_key, value in state_dict.items():
        new_state_dict[get_new_key(old_key)] = value

    return new_state_dict
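# Usage sketch (hypothetical single-entry key map, for illustration only):
#
#     >>> key_map = {r"decoder\.layer_norm\.": r"norm."}
#     >>> convert_model_state_dict({"decoder.layer_norm.weight": 1}, key_map)
#     {'norm.weight': 1}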
def convert_to_llama_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a fairseq2 LLaMA checkpoint to the reference format."""
    # state_dict = checkpoint["model"]

    # Map fairseq2 module paths to the reference implementation's key names.
    # NOTE: Dots in the patterns are escaped so that each one matches the
    # literal module path rather than any character.
    key_map = {
        # fmt: off
        r"decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"layers.\1.attention.wq.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"layers.\1.attention.wk.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"layers.\1.attention.wv.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.output_proj\.": r"layers.\1.attention.wo.",
        r"decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"layers.\1.attention_norm.",
        r"decoder\.layers\.([0-9]+)\.ffn\.gate_proj\.": r"layers.\1.feed_forward.w1.",
        r"decoder\.layers\.([0-9]+)\.ffn\.output_proj\.": r"layers.\1.feed_forward.w2.",
        r"decoder\.layers\.([0-9]+)\.ffn\.inner_proj\.": r"layers.\1.feed_forward.w3.",
        r"decoder\.layers\.([0-9]+)\.ffn_layer_norm\.": r"layers.\1.ffn_norm.",
        r"decoder\.layer_norm\.": r"norm.",
        r"decoder_frontend\.embed\.": r"tok_embeddings.",
        r"final_proj\.": r"output.",
        # fmt: on
    }

    return convert_model_state_dict(checkpoint, key_map)
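

if __name__ == "__main__":
    # Minimal smoke test, not part of the original module. The keys are
    # illustrative examples drawn from the patterns in key_map above, and
    # the integer values stand in for real checkpoint tensors.
    demo_checkpoint = {
        "decoder.layers.0.self_attn.q_proj.weight": 0,
        "decoder.layers.0.ffn.gate_proj.weight": 1,
        "decoder.layer_norm.weight": 2,
        "decoder_frontend.embed.weight": 3,
        "final_proj.weight": 4,
    }
    converted = convert_to_llama_checkpoint(demo_checkpoint)
    assert "layers.0.attention.wq.weight" in converted
    assert "layers.0.feed_forward.w1.weight" in converted
    assert "norm.weight" in converted
    assert "tok_embeddings.weight" in converted
    assert "output.weight" in converted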