# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to run the phi-3-mini model in eager mode.
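#
# Example usage (a sketch; the flags are defined by the argument parser at the
# bottom of this file, and the exact invocation path depends on how the
# example package is laid out in your checkout):
#
#   eager.py --prompt "Tell me a story" --seq_len 64 --use_kv_cache
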
import argparse
import time

import torch

from transformers import AutoTokenizer, Phi3ForCausalLM

from .phi_3_mini import Phi3Mini

# Token id of <|endoftext|> in the Phi-3-mini tokenizer; generation stops once
# the model emits it.
end_of_text_token = 32000


def _generate_token(args, model, prompt_tokens):
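    """Greedy decoding without a KV cache.

    Each step re-runs the entire prompt-plus-generated sequence through the
    model and appends the argmax of the logits at the last position, until the
    end-of-text token appears or args.seq_len tokens have been generated.
    """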
    current_token = 0
    generated_tokens = []

    print("Generating tokens:", end="", flush=True)

    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
        outputs = model.forward(input_ids=prompt_tokens)
        current_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)
        prompt_tokens = torch.cat(
            [prompt_tokens, torch.tensor([[current_token]], dtype=torch.long)], dim=-1
        )

    print("", flush=True)

    return generated_tokens


def _generate_token_with_kv_cache(args, model, prompt_tokens):
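    """Greedy decoding with the KV cache managed by the Phi3Mini wrapper.

    The prompt is fed to the wrapped model once; after that, each step feeds
    only the single most recent token, until the end-of-text token appears or
    args.seq_len tokens have been generated.
    """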
    print("Generating tokens:", end="", flush=True)

    # Wrap the eager model so that it keeps a KV cache sized for the prompt
    # plus the requested number of new tokens.
    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])
    result = model.forward(input_ids=prompt_tokens)

    current_token = torch.argmax(result, dim=-1).item()
    print(f" {current_token}", end="", flush=True)
    generated_tokens = [current_token]

    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
        result = model.forward(
            input_ids=torch.tensor([[current_token]], dtype=torch.long),
        )
        current_token = torch.argmax(result, dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)

    print("", flush=True)

    return generated_tokens


def main(args):
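    """Load the Hugging Face phi-3-mini model and tokenizer, generate a
    response for args.prompt, and print the decoded text and elapsed time."""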
    seed = 42
    torch.manual_seed(seed)
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokens = tokenizer.encode(args.prompt, return_tensors="pt")

    start = time.time()
    generated_tokens = (
        _generate_token_with_kv_cache(args, model, tokens)
        if args.use_kv_cache
        else _generate_token(args, model, tokens)
    )
    end = time.time()

    print(
        "Generated response: \n {}".format(
            tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        ),
        flush=True,
    )
    print(f"Time spent: {end - start} seconds", flush=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--seq_len",
        type=int,
        default=128,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "-kv",
        "--use_kv_cache",
        default=False,
        action="store_true",
        help="Use the KV cache during generation",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default="Tell me a story",
        help="Prompt as input for the model",
    )
    main(parser.parse_args())