1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7from abc import ABC, abstractmethod 8from typing import List, Optional 9 10import torch 11 12from executorch.extension.llm.tokenizer.utils import get_tokenizer 13 14 15def sample_top_p(probs, p): 16 """ 17 Perform top-p (nucleus) sampling on a probability distribution. 18 19 Args: 20 probs (torch.Tensor): Probability distribution tensor. 21 p (float): Probability threshold for top-p sampling. 22 23 Returns: 24 torch.Tensor: Sampled token indices. 25 26 Note: 27 Top-p sampling selects the smallest set of tokens whose cumulative probability mass 28 exceeds the threshold p. The distribution is re-normalized based on the selected tokens. 29 """ 30 probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 31 probs_sum = torch.cumsum(probs_sort, dim=-1) 32 mask = probs_sum - probs_sort > p 33 probs_sort[mask] = 0.0 34 probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 35 next_token = torch.multinomial(probs_sort, num_samples=1) 36 next_token = torch.gather(probs_idx, -1, next_token) 37 return next_token 38 39 40def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int: 41 if temperature > 0: 42 probs = torch.softmax(logits / temperature, dim=-1) 43 return sample_top_p(probs, top_p).item() 44 # Pyre-ignore[7]: Incompatible return type [7]: Expected `int` but got `Union[bool, float, int]` 45 return torch.argmax(logits, dim=-1).item() 46 47 48class LlamaRunner(ABC): 49 def __init__( 50 self, 51 tokenizer_path: str, 52 max_seq_len: int, 53 max_batch_size: int, 54 use_kv_cache: bool, 55 vocab_size: int, 56 device: str = "cpu", 57 ): 58 """ 59 Constructor. 60 61 Args: 62 tokenizer_path: path to tokenizer.model file. 63 max_seq_len: max length of the output sequence, after which the output will be clipped. 64 max_batch_size: max batch size. 65 use_kv_cache: whether to use a KV cache. 66 vocab_size: number of items in the vocab. 67 device: device to run the runner on. 68 """ 69 self.max_seq_len = max_seq_len 70 self.max_batch_size = max_batch_size 71 self.use_kv_cache = use_kv_cache 72 self.tokenizer = get_tokenizer(tokenizer_path) 73 self.device = device 74 assert vocab_size == self.tokenizer.n_words 75 76 @abstractmethod 77 def forward( 78 self, 79 tokens: torch.Tensor, 80 input_pos: Optional[torch.Tensor] = None, 81 ) -> torch.Tensor: 82 pass 83 84 def generate( # noqa: C901 85 self, 86 prompt_tokens: List[int], 87 max_seq_len: int, 88 temperature: float = 0.8, 89 top_p: float = 0.9, 90 echo: bool = False, 91 pos_base: int = 0, 92 ) -> List[int]: 93 # Prefill 94 logits = self.forward( 95 tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), 96 input_pos=( 97 torch.tensor([pos_base], dtype=torch.long, device=self.device) 98 if self.use_kv_cache 99 else None 100 ), 101 ) 102 103 current_token = next_token(logits, temperature, top_p) 104 print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) 105 tokens = prompt_tokens + [current_token] 106 107 while len(tokens) < max_seq_len: 108 if self.use_kv_cache: 109 logits = self.forward( 110 tokens=torch.tensor( 111 [[current_token]], dtype=torch.long, device=self.device 112 ), 113 input_pos=torch.tensor( 114 [pos_base + len(tokens) - 1], 115 dtype=torch.long, 116 device=self.device, 117 ), 118 ) 119 else: 120 logits = self.forward( 121 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device), 122 ) 123 124 # If the logits aren't already clipped to only contain the last logit, clip them. 125 current_token = next_token(logits, temperature, top_p) 126 tokens.append(current_token) 127 128 if current_token == self.tokenizer.eos_id or ( 129 hasattr(self.tokenizer, "stop_tokens") 130 and current_token in self.tokenizer.stop_tokens 131 ): 132 break 133 134 print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) 135 print("\n") 136 137 return tokens if echo else tokens[len(prompt_tokens) :] 138 139 def text_completion( 140 self, 141 prompt: str, 142 temperature: float = 0.6, 143 top_p: float = 0.9, 144 echo: bool = False, 145 ) -> List[int]: 146 """ 147 Perform text completion for a prompt using the language model. 148 149 Args: 150 prompt (str): Text prompt for completion. 151 temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. 152 top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 153 echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. 154 155 Returns: 156 Generated list of tokens. 157 158 Note: 159 This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. 160 """ 161 return self.generate( 162 prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), 163 max_seq_len=self.max_seq_len, 164 temperature=temperature, 165 top_p=top_p, 166 echo=echo, 167 ) 168 169 def chat_completion( 170 self, 171 temperature: float = 0.6, 172 top_p: float = 0.9, 173 ) -> List[int]: 174 """ 175 Perform multi-turn chat with the language model. 176 177 Args: 178 prompt (str): Text prompt for completion. 179 temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. 180 top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. 181 echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. 182 183 Returns: 184 Generated list of tokens. 185 186 Note: 187 This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. 188 """ 189 exit_prompt = "exit" 190 tokens = [] 191 prompt = input("Me: ") 192 while prompt and prompt != exit_prompt: 193 print("LLM: ", end="", flush=True) 194 new_tokens = self.generate( 195 prompt_tokens=self.tokenizer.encode( 196 self._format_prompt(prompt), bos=True, eos=False 197 ), 198 max_seq_len=self.max_seq_len, 199 temperature=temperature, 200 top_p=top_p, 201 echo=True, 202 pos_base=len(tokens) - 1 if len(tokens) > 0 else 0, 203 ) 204 tokens.extend(new_tokens) 205 prompt = input("Me: ") 206 return tokens 207 208 def _format_prompt(self, prompt: str) -> str: 209 return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> 210 211You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|> 212 213{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>""" 214