# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext

from models.llm_models.configuration_base import BaseConfig


# flake8: noqa: C901


class LlamaConfig(BaseConfig):
    """Configuration for LLaMA-family models.

    Validates the architecture fields loaded from config.json (vocabulary size,
    hidden/intermediate sizes, layer and head counts, positional-embedding and
    normalization settings) and, unless ``verbose=False``, prints a summary of
    the resulting configuration.
    """

    def __init__(
        self,
        vocab_size=None,
        hidden_size=None,
        intermediate_size=None,
        num_hidden_layers=None,
        num_attention_heads=None,
        max_position_embeddings=None,
        norm="RMSNorm",
        position_embedding="rope",
        norm_eps=1e-6,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        unk_token_id=0,
        use_stable_embedding=False,
        tie_word_embeddings=False,
        combine_qkv=False,
        response_handler=None,
        **kwargs,
    ):
        super().__init__()

        self.model_type = "llama"
        self.vocab_size = vocab_size
        if self.vocab_size is None:
            raise KeyError("vocab_size is required but missing from config.json")
        self.hidden_size = hidden_size
        if self.hidden_size is None:
            raise KeyError("hidden_size is required but missing from config.json")
        self.intermediate_size = intermediate_size
        if self.intermediate_size is None:
            raise KeyError("intermediate_size is required but missing from config.json")
        self.num_hidden_layers = num_hidden_layers
        if self.num_hidden_layers is None:
            raise KeyError("num_hidden_layers is required but missing from config.json")
        self.num_attention_heads = num_attention_heads
        if self.num_attention_heads is None:
            raise KeyError(
                "num_attention_heads is required but missing from config.json"
            )
        # Default to multi-head attention (one KV head per attention head)
        # when num_key_value_heads is not provided.
        self.num_key_value_heads = kwargs.pop(
            "num_key_value_heads", self.num_attention_heads
        )
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise RuntimeError(
                f"num_attention_heads ({self.num_attention_heads}) must be exactly "
                f"divisible by num_key_value_heads ({self.num_key_value_heads})"
            )
        if norm not in ["RMSNorm", "LayerNorm"]:
            raise ValueError("norm must be one of: RMSNorm (default) or LayerNorm")
        self.norm = norm
        # Accept the HF-style "rms_norm_eps" key as an alias for norm_eps.
        self.norm_eps = kwargs.pop("rms_norm_eps", norm_eps)
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.unk_token_id = unk_token_id

        if position_embedding not in ["rope", "alibi"]:
            raise ValueError("Positional embedding must be one of: rope, alibi")
        self.position_embedding = position_embedding
        self.ntk_scaling_factor = kwargs.pop("ntk_scaling_factor", 1.0)
        if self.ntk_scaling_factor != 1.0 and self.position_embedding != "rope":
            raise KeyError("ntk_scaling_factor is strictly for position_embedding=rope")
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings is None and self.position_embedding == "rope":
            raise KeyError(
                "max_position_embeddings is required for position_embedding=rope but missing from config.json"
            )

        self.use_stable_embedding = use_stable_embedding
        self.tie_word_embeddings = tie_word_embeddings
        self.combine_qkv = combine_qkv

        self.tokenizer = kwargs.pop("tokenizer", self.tokenizer)

        if response_handler is None:
            response_handler = nullcontext()
        if kwargs.pop("verbose", True):
            self.print_config(response_handler)

    def print_config(self, response_handler):
        with response_handler:
            print(f"{self.model_type} config:")
            print(f"Hidden size: {self.hidden_size}")
            print(f"Intermediate size: {self.intermediate_size}")
            print(f"Num layers: {self.num_hidden_layers}")
            print(f"Num attention heads: {self.num_attention_heads}")
            print(f"Num KV heads: {self.num_key_value_heads}")
            print(f"Positional embedding: {self.position_embedding}")
            if self.position_embedding == "rope":
                print(f"Max pos emb: {self.max_position_embeddings}")
                if self.ntk_scaling_factor != 1.0:
                    print(f"NTK scaling factor: {self.ntk_scaling_factor}")
            print(f"Norm type: {self.norm}")
            print(f"Norm epsilon: {self.norm_eps}")
            print(f"BOS token id: {self.bos_token_id}")
            print(f"EOS token id: {self.eos_token_id}")
            print(f"PAD token id: {self.pad_token_id}")
            print(f"UNK token id: {self.unk_token_id}")
            print(f"Vocab size: {self.vocab_size}")
            print(f"Use stable embedding: {self.use_stable_embedding}")
            print(f"Tie word embeddings: {self.tie_word_embeddings}")
            print(f"Combine QKV: {self.combine_qkv}")
            if self.tokenizer != "default":
                print(f"Tokenizer: {self.tokenizer}")
            print()
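

# Minimal usage sketch: the hyperparameter values below are assumptions that
# mirror a Llama-7B-style model, not values taken from this repository, and
# BaseConfig is assumed to require no additional constructor arguments and to
# define a default `tokenizer` attribute.
if __name__ == "__main__":
    config = LlamaConfig(
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        max_position_embeddings=2048,
        num_key_value_heads=32,  # equal to num_attention_heads, i.e. plain MHA
        verbose=True,  # prints the configuration summary via print_config()
    )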