xref: /aosp_15_r20/external/executorch/examples/models/llama/fairseq2.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport re
8*523fa7a6SAndroid Build Coastguard Workerfrom typing import Any, Dict, Mapping
9*523fa7a6SAndroid Build Coastguard Worker
10*523fa7a6SAndroid Build Coastguard Worker
def convert_model_state_dict(
    state_dict: Dict[str, Any], key_map: Mapping[str, str]
) -> Dict[str, Any]:
    """Rename the keys of a model state dictionary using regex rules.

    Each key of ``state_dict`` is rewritten by the first ``key_map`` entry
    whose pattern actually changes it. The rewrite uses ``re.sub`` semantics:
    every match in the key is replaced, and group back-references are
    supported in the replacement. Keys that no pattern changes are copied
    through unchanged. Values are not copied; the returned dictionary refers
    to the same objects as ``state_dict``.

    :param state_dict:
        The original model state dictionary. It is not modified.
    :param key_map:
        A map of regex patterns to replacement strings. Patterns are tried in
        the mapping's iteration order.

    :returns:
        A new dictionary with renamed keys, in the same order as
        ``state_dict``. If two original keys are renamed to the same new key,
        the value of the later one wins.
    """
    # Compile every pattern once up front rather than per state-dict key.
    rules = [
        (re.compile(pattern), replacement)
        for pattern, replacement in key_map.items()
    ]

    def get_new_key(old_key: str) -> str:
        for pattern, replacement in rules:
            # Only a substitution that changes the key counts as a match;
            # a pattern whose rewrite leaves the key unchanged falls through
            # to the next rule.
            if (new_key := pattern.sub(replacement, old_key)) != old_key:
                return new_key

        return old_key

    return {get_new_key(old_key): value for old_key, value in state_dict.items()}
40*523fa7a6SAndroid Build Coastguard Worker
41*523fa7a6SAndroid Build Coastguard Worker
def convert_to_llama_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a fairseq2 LLaMA state dictionary to the reference format.

    :param checkpoint:
        The fairseq2 model state dictionary. It is passed to the key
        conversion as-is, so it must map parameter names directly to values;
        it is not unwrapped from an outer ``"model"`` entry.

    :returns:
        A new dictionary holding the same values under reference LLaMA key
        names (e.g. ``decoder.layers.0.self_attn.q_proj.weight`` becomes
        ``layers.0.attention.wq.weight``). Keys matched by no pattern are
        kept unchanged.
    """
    # Patterns are tried in order and the first one that changes a key wins.
    # Literal dots are escaped so `.` does not act as a regex wildcard. The
    # patterns are deliberately unanchored, so any prefix in front of the
    # fairseq2 name is carried over into the converted key.
    key_map = {
        # fmt: off
        r"decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"layers.\1.attention.wq.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"layers.\1.attention.wk.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"layers.\1.attention.wv.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.output_proj\.": r"layers.\1.attention.wo.",
        r"decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.": r"layers.\1.attention_norm.",
        r"decoder\.layers\.([0-9]+)\.ffn\.gate_proj\.": r"layers.\1.feed_forward.w1.",
        r"decoder\.layers\.([0-9]+)\.ffn\.output_proj\.": r"layers.\1.feed_forward.w2.",
        r"decoder\.layers\.([0-9]+)\.ffn\.inner_proj\.": r"layers.\1.feed_forward.w3.",
        r"decoder\.layers\.([0-9]+)\.ffn_layer_norm\.": r"layers.\1.ffn_norm.",
        r"decoder\.layer_norm\.": r"norm.",
        r"decoder_frontend\.embed\.": r"tok_embeddings.",
        r"final_proj\.": r"output.",
        # fmt: on
    }

    return convert_model_state_dict(checkpoint, key_map)
64