# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext

from models.llm_models.configuration_base import BaseConfig


# flake8: noqa: C901


class LlamaConfig(BaseConfig):
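    """Configuration for Llama-family decoder models.

    Core architecture sizes (vocab_size, hidden_size, intermediate_size,
    num_hidden_layers, num_attention_heads) must be supplied, typically from a
    model's config.json; the remaining options default to standard Llama
    settings (RMSNorm, RoPE positional embeddings, Llama token ids).
    """
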
    def __init__(
        self,
        vocab_size=None,
        hidden_size=None,
        intermediate_size=None,
        num_hidden_layers=None,
        num_attention_heads=None,
        max_position_embeddings=None,
        norm="RMSNorm",
        position_embedding="rope",
        norm_eps=1e-6,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        unk_token_id=0,
        use_stable_embedding=False,
        tie_word_embeddings=False,
        combine_qkv=False,
        response_handler=None,
        **kwargs,
    ):
        super().__init__()

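        # Core architecture sizes have no sensible defaults; fail fast when
        # config.json omits any of them.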
        self.model_type = "llama"
        self.vocab_size = vocab_size
        if self.vocab_size is None:
            raise KeyError("vocab_size is required but missing from config.json")
        self.hidden_size = hidden_size
        if self.hidden_size is None:
            raise KeyError("hidden_size is required but missing from config.json")
        self.intermediate_size = intermediate_size
        if self.intermediate_size is None:
            raise KeyError("intermediate_size is required but missing from config.json")
        self.num_hidden_layers = num_hidden_layers
        if self.num_hidden_layers is None:
            raise KeyError("num_hidden_layers is required but missing from config.json")
        self.num_attention_heads = num_attention_heads
        if self.num_attention_heads is None:
            raise KeyError(
                "num_attention_heads is required but missing from config.json"
            )
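        # num_key_value_heads defaults to num_attention_heads (plain multi-head
        # attention); a smaller value enables grouped-query attention and must
        # divide the head count evenly.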
        self.num_key_value_heads = kwargs.pop(
            "num_key_value_heads", self.num_attention_heads
        )
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise RuntimeError(
                f"num_attention_heads ({self.num_attention_heads}) must be exactly "
                f"divisible by num_key_value_heads ({self.num_key_value_heads})"
            )
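        # Normalization layer type; an "rms_norm_eps" entry in config.json
        # overrides the norm_eps argument.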
        if norm not in ["RMSNorm", "LayerNorm"]:
            raise ValueError("norm must be one of: RMSNorm (default) or LayerNorm")
        self.norm = norm
        self.norm_eps = kwargs.pop("rms_norm_eps", norm_eps)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.unk_token_id = unk_token_id

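        # Positional embedding scheme: RoPE (optionally NTK-scaled) or ALiBi.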
        if position_embedding not in ["rope", "alibi"]:
            raise ValueError("Positional embedding must be one of: rope, alibi")
        self.position_embedding = position_embedding
        self.ntk_scaling_factor = kwargs.pop("ntk_scaling_factor", 1.0)
        if self.ntk_scaling_factor != 1.0 and self.position_embedding != "rope":
            raise KeyError("ntk_scaling_factor is strictly for position_embedding=rope")
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings is None and self.position_embedding == "rope":
            raise KeyError(
                "max_position_embeddings is required for position_embedding=rope but missing from config.json"
            )

        self.use_stable_embedding = use_stable_embedding
        self.tie_word_embeddings = tie_word_embeddings
        self.combine_qkv = combine_qkv

        self.tokenizer = kwargs.pop("tokenizer", self.tokenizer)

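        # With no response handler supplied, print inside a no-op context
        # manager; pass verbose=False to suppress the summary entirely.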
        if response_handler is None:
            response_handler = nullcontext()
        if kwargs.pop("verbose", True):
            self.print_config(response_handler)

    def print_config(self, response_handler):
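        """Print a human-readable summary of the configuration inside response_handler."""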
        with response_handler:
            print(f"{self.model_type} config:")
            print(f"Hidden size:          {self.hidden_size}")
            print(f"Intermediate size:    {self.intermediate_size}")
            print(f"Num layers:           {self.num_hidden_layers}")
            print(f"Num attention heads:  {self.num_attention_heads}")
            print(f"Num KV heads:         {self.num_key_value_heads}")
            print(f"Positional embedding: {self.position_embedding}")
            if self.position_embedding == "rope":
                print(f"Max pos emb:          {self.max_position_embeddings}")
                if self.ntk_scaling_factor != 1.0:
                    print(f"NTK scaling factor:   {self.ntk_scaling_factor}")
            print(f"Norm type:            {self.norm}")
            print(f"Norm epsilon:         {self.norm_eps}")
            print(f"BOS token id:         {self.bos_token_id}")
            print(f"EOS token id:         {self.eos_token_id}")
            print(f"PAD token id:         {self.pad_token_id}")
            print(f"UNK token id:         {self.unk_token_id}")
            print(f"Vocab size:           {self.vocab_size}")
            print(f"Use stable embedding: {self.use_stable_embedding}")
            print(f"Tie word embeddings:  {self.tie_word_embeddings}")
            print(f"Combine QKV:          {self.combine_qkv}")
            if self.tokenizer != "default":
                print(f"Tokenizer:            {self.tokenizer}")
            print()
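

# Illustrative usage sketch: constructing a LlamaConfig directly from a parsed
# config.json dictionary. The numeric values below are hypothetical placeholders
# chosen only to show the expected shape of the input, not a real checkpoint.
if __name__ == "__main__":
    example_config = {
        "vocab_size": 32000,
        "hidden_size": 4096,
        "intermediate_size": 11008,
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,  # smaller than num_attention_heads -> grouped-query attention
        "max_position_embeddings": 4096,
        "rms_norm_eps": 1e-5,  # consumed from kwargs in place of norm_eps
    }
    # verbose defaults to True, so the resolved config is printed on construction.
    config = LlamaConfig(**example_config)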