xref: /aosp_15_r20/external/executorch/examples/models/llama/runner/generation.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
from typing import List, Optional

import torch

from executorch.extension.llm.tokenizer.utils import get_tokenizer


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is re-normalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


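# Illustrative example (not part of the original module): with the toy
# distribution below and p=0.5, only the two most likely tokens survive the
# mask, forming the smallest set whose cumulative mass (0.7) exceeds p, so the
# sampled index is always 0 or 1.
def _example_sample_top_p() -> None:
    example_probs = torch.tensor([[0.4, 0.3, 0.2, 0.1]])
    sampled = sample_top_p(example_probs, p=0.5)
    assert sampled.item() in (0, 1)

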
def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
    """
    Select the next token id from the logits: top-p sampling over the
    temperature-scaled softmax when temperature > 0, greedy argmax otherwise.
    """
    if temperature > 0:
        probs = torch.softmax(logits / temperature, dim=-1)
        return sample_top_p(probs, top_p).item()
    # Pyre-ignore[7]: Incompatible return type [7]: Expected `int` but got `Union[bool, float, int]`
    return torch.argmax(logits, dim=-1).item()


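# Illustrative example (not part of the original module): temperature 0 reduces
# next_token() to greedy argmax, while a positive temperature samples from the
# softmax of the scaled logits, restricted to the top-p nucleus.
def _example_next_token() -> None:
    example_logits = torch.tensor([1.0, 3.0, 2.0])
    assert next_token(example_logits, temperature=0.0, top_p=0.9) == 1
    assert next_token(example_logits, temperature=0.7, top_p=0.9) in (0, 1, 2)

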
class LlamaRunner(ABC):
    def __init__(
        self,
        tokenizer_path: str,
        max_seq_len: int,
        max_batch_size: int,
        use_kv_cache: bool,
        vocab_size: int,
        device: str = "cpu",
    ):
        """
        Constructor.

        Args:
            tokenizer_path: path to tokenizer.model file.
            max_seq_len: max length of the output sequence, after which the output will be clipped.
            max_batch_size: max batch size.
            use_kv_cache: whether to use a KV cache.
            vocab_size: number of items in the vocab.
            device: device to run the runner on.
        """
        self.max_seq_len = max_seq_len
        self.max_batch_size = max_batch_size
        self.use_kv_cache = use_kv_cache
        self.tokenizer = get_tokenizer(tokenizer_path)
        self.device = device
        assert vocab_size == self.tokenizer.n_words

    @abstractmethod
    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the model on the given token ids (with `input_pos` providing the
        starting position when a KV cache is used) and return the logits for
        the last position, as consumed by next_token().
        """
        pass

    def generate(  # noqa: C901
        self,
        prompt_tokens: List[int],
        max_seq_len: int,
        temperature: float = 0.8,
        top_p: float = 0.9,
        echo: bool = False,
        pos_base: int = 0,
    ) -> List[int]:
        """
        Generate tokens autoregressively from `prompt_tokens` until `max_seq_len`
        is reached or a stop token is produced.

        Args:
            prompt_tokens: token ids of the prompt.
            max_seq_len: maximum total sequence length, after which generation stops.
            temperature: sampling temperature; non-positive values select greedy decoding.
            top_p: top-p (nucleus) sampling threshold.
            echo: whether to include the prompt tokens in the returned list.
            pos_base: starting position for the KV cache, used to continue a multi-turn conversation.

        Returns:
            Generated list of tokens.
        """
        # Prefill: run the full prompt through the model in a single forward pass.
        logits = self.forward(
            tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
            input_pos=(
                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                if self.use_kv_cache
                else None
            ),
        )

        current_token = next_token(logits, temperature, top_p)
        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        tokens = prompt_tokens + [current_token]

        # Decode: generate one token per step. With a KV cache, only the newly
        # generated token and its position are fed back; otherwise the whole
        # sequence is re-run each step.
        while len(tokens) < max_seq_len:
            if self.use_kv_cache:
                logits = self.forward(
                    tokens=torch.tensor(
                        [[current_token]], dtype=torch.long, device=self.device
                    ),
                    input_pos=torch.tensor(
                        [pos_base + len(tokens) - 1],
                        dtype=torch.long,
                        device=self.device,
                    ),
                )
            else:
                logits = self.forward(
                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
                )

            # Sample the next token from the last-position logits returned by forward().
            current_token = next_token(logits, temperature, top_p)
            tokens.append(current_token)

            # Stop on the end-of-sequence token or any model-specific stop token.
            if current_token == self.tokenizer.eos_id or (
                hasattr(self.tokenizer, "stop_tokens")
                and current_token in self.tokenizer.stop_tokens
            ):
                break

            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
        print("\n")

        return tokens if echo else tokens[len(prompt_tokens) :]

    def text_completion(
        self,
        prompt: str,
        temperature: float = 0.6,
        top_p: float = 0.9,
        echo: bool = False,
    ) -> List[int]:
        """
        Perform text completion for a prompt using the language model.

        Args:
            prompt (str): Text prompt for completion.
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.

        Returns:
            Generated list of tokens.

        Note:
            This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
        """
        return self.generate(
            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
            max_seq_len=self.max_seq_len,
            temperature=temperature,
            top_p=top_p,
            echo=echo,
        )

    def chat_completion(
        self,
        temperature: float = 0.6,
        top_p: float = 0.9,
    ) -> List[int]:
        """
        Perform multi-turn chat with the language model.

        Args:
            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.

        Returns:
            List of tokens accumulated over all turns (prompts and generated tokens).

        Note:
            This method reads user prompts from stdin until "exit" is entered, wraps
            each turn with the chat template, and continues generation from the
            accumulated token history using nucleus sampling.
        """
        exit_prompt = "exit"
        tokens = []
        prompt = input("Me: ")
        while prompt and prompt != exit_prompt:
            print("LLM: ", end="", flush=True)
            new_tokens = self.generate(
                prompt_tokens=self.tokenizer.encode(
                    self._format_prompt(prompt), bos=True, eos=False
                ),
                max_seq_len=self.max_seq_len,
                temperature=temperature,
                top_p=top_p,
                echo=True,
                pos_base=len(tokens) - 1 if len(tokens) > 0 else 0,
            )
            tokens.extend(new_tokens)
            prompt = input("Me: ")
        return tokens

    def _format_prompt(self, prompt: str) -> str:
        return f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

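
# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: a minimal concrete
# runner built on LlamaRunner, so the abstract forward() contract (token ids
# in, last-position logits out) is concrete. _ToyModel is a random stand-in
# for a real exported Llama model, and the tokenizer path in the __main__
# block is a hypothetical placeholder the caller must replace. Real runners
# implement forward() against an actual model artifact; chat_completion()
# additionally wraps each user turn with the Llama-3-style template emitted
# by _format_prompt().
class _ToyModel(torch.nn.Module):
    """Stand-in model: embeds tokens and projects back to vocabulary logits."""

    def __init__(self, vocab_size: int, dim: int = 32):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, dim)
        self.output = torch.nn.Linear(dim, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # [batch, seq] -> [batch, seq, vocab]; keep only the last position so
        # next_token() can reduce the logits to a single token id.
        return self.output(self.embed(tokens))[:, -1, :]


class _ToyLlamaRunner(LlamaRunner):
    """Hypothetical runner wiring _ToyModel into the generation loop."""

    def __init__(self, tokenizer_path: str):
        # Load the tokenizer up front to obtain the vocab size expected by
        # LlamaRunner.__init__ (which re-creates it and asserts the two match).
        n_words = get_tokenizer(tokenizer_path).n_words
        super().__init__(
            tokenizer_path=tokenizer_path,
            max_seq_len=128,
            max_batch_size=1,
            use_kv_cache=False,
            vocab_size=n_words,
            device="cpu",
        )
        self.model = _ToyModel(n_words)

    def forward(
        self,
        tokens: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # input_pos is ignored because this toy model keeps no KV cache.
        return self.model(tokens)


if __name__ == "__main__":
    # Requires a real tokenizer.model file; the path below is a placeholder.
    runner = _ToyLlamaRunner(tokenizer_path="/path/to/tokenizer.model")
    runner.text_completion(prompt="Hello", temperature=0.0)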