# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import re
from typing import Any, Dict, Mapping


def convert_model_state_dict(
    state_dict: Dict[str, Any], key_map: Mapping[str, str]
) -> Dict[str, Any]:
    """Rename the keys of a model state dictionary using regex rules.

    Each key is tested against the patterns of ``key_map`` in iteration
    order. The first pattern whose substitution actually changes the key
    wins; keys that no pattern changes are copied over unmodified. Values
    are never touched and ``state_dict`` itself is not mutated.

    :param state_dict:
        The original model state dictionary.
    :param key_map:
        A map of regex patterns to replacement templates (``re.sub``
        syntax, so ``\\1`` etc. refer to groups of the pattern).

    :returns:
        A new dictionary holding the same values under the converted keys.
        If two original keys convert to the same new key, the later one
        overwrites the earlier one.
    """
    # Compile every pattern once up front instead of once per key.
    rules = [
        (re.compile(pattern), replacement)
        for pattern, replacement in key_map.items()
    ]

    def get_new_key(old_key: str) -> str:
        for pattern, replacement in rules:
            # A rule only "applies" when it changes the key; otherwise keep
            # trying the remaining rules.
            if (new_key := pattern.sub(replacement, old_key)) != old_key:
                return new_key

        return old_key

    return {get_new_key(old_key): value for old_key, value in state_dict.items()}


def convert_to_llama_checkpoint(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Convert a fairseq2 LLaMA checkpoint to the reference format.

    :param checkpoint:
        The fairseq2 state dictionary. NOTE: it is used as the flat state
        dictionary itself; it is not unwrapped from a ``"model"`` entry.

    :returns:
        A new dictionary with the same values under reference-format keys.
        Keys that match none of the known fairseq2 module paths are kept
        unchanged.
    """
    # Dots are escaped so that they only match a literal "." separator; an
    # unescaped "." would match any character and could rewrite unrelated
    # keys (e.g. "final_projection" would match "final_proj.").
    key_map = {
        # fmt: off
        r"decoder\.layers\.([0-9]+)\.self_attn\.q_proj\.":       r"layers.\1.attention.wq.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.k_proj\.":       r"layers.\1.attention.wk.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.v_proj\.":       r"layers.\1.attention.wv.",
        r"decoder\.layers\.([0-9]+)\.self_attn\.output_proj\.":  r"layers.\1.attention.wo.",
        r"decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"layers.\1.attention_norm.",
        r"decoder\.layers\.([0-9]+)\.ffn\.gate_proj\.":          r"layers.\1.feed_forward.w1.",
        r"decoder\.layers\.([0-9]+)\.ffn\.output_proj\.":        r"layers.\1.feed_forward.w2.",
        r"decoder\.layers\.([0-9]+)\.ffn\.inner_proj\.":         r"layers.\1.feed_forward.w3.",
        r"decoder\.layers\.([0-9]+)\.ffn_layer_norm\.":          r"layers.\1.ffn_norm.",
        r"decoder\.layer_norm\.":                                r"norm.",
        r"decoder_frontend\.embed\.":                            r"tok_embeddings.",
        r"final_proj\.":                                         r"output.",
        # fmt: on
    }

    return convert_model_state_dict(checkpoint, key_map)