# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to run phi-3-mini model in eager mode.

import argparse
import time

import torch

from transformers import AutoTokenizer, Phi3ForCausalLM

from .phi_3_mini import Phi3Mini

end_of_text_token = 32000


22*523fa7a6SAndroid Build Coastguard Workerdef _generate_token(args, model, prompt_tokens):
23*523fa7a6SAndroid Build Coastguard Worker    current_token = 0
24*523fa7a6SAndroid Build Coastguard Worker    generated_tokens = []
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Worker    print("Generating tokens:", end="", flush=True)
27*523fa7a6SAndroid Build Coastguard Worker
28*523fa7a6SAndroid Build Coastguard Worker    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
29*523fa7a6SAndroid Build Coastguard Worker        outputs = model.forward(input_ids=prompt_tokens)
30*523fa7a6SAndroid Build Coastguard Worker        current_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
31*523fa7a6SAndroid Build Coastguard Worker        print(f" {current_token}", end="", flush=True)
32*523fa7a6SAndroid Build Coastguard Worker        generated_tokens.append(current_token)
33*523fa7a6SAndroid Build Coastguard Worker        prompt_tokens = torch.cat(
34*523fa7a6SAndroid Build Coastguard Worker            [prompt_tokens, torch.tensor([[current_token]], dtype=torch.long)], dim=-1
35*523fa7a6SAndroid Build Coastguard Worker        )
36*523fa7a6SAndroid Build Coastguard Worker
37*523fa7a6SAndroid Build Coastguard Worker    print("", flush=True)
38*523fa7a6SAndroid Build Coastguard Worker
39*523fa7a6SAndroid Build Coastguard Worker    return generated_tokens
40*523fa7a6SAndroid Build Coastguard Worker
41*523fa7a6SAndroid Build Coastguard Worker
42*523fa7a6SAndroid Build Coastguard Workerdef _generate_token_with_kv_cache(args, model, prompt_tokens):
43*523fa7a6SAndroid Build Coastguard Worker    print("Generating tokens:", end="", flush=True)
44*523fa7a6SAndroid Build Coastguard Worker
45*523fa7a6SAndroid Build Coastguard Worker    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])
46*523fa7a6SAndroid Build Coastguard Worker    result = model.forward(input_ids=prompt_tokens)
47*523fa7a6SAndroid Build Coastguard Worker
48*523fa7a6SAndroid Build Coastguard Worker    current_token = torch.argmax(result, dim=-1).item()
49*523fa7a6SAndroid Build Coastguard Worker    print(f" {current_token}", end="", flush=True)
50*523fa7a6SAndroid Build Coastguard Worker    generated_tokens = [current_token]
51*523fa7a6SAndroid Build Coastguard Worker
52*523fa7a6SAndroid Build Coastguard Worker    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
53*523fa7a6SAndroid Build Coastguard Worker        result = model.forward(
54*523fa7a6SAndroid Build Coastguard Worker            input_ids=torch.tensor([[current_token]], dtype=torch.long),
55*523fa7a6SAndroid Build Coastguard Worker        )
56*523fa7a6SAndroid Build Coastguard Worker        current_token = torch.argmax(result, dim=-1).item()
57*523fa7a6SAndroid Build Coastguard Worker        print(f" {current_token}", end="", flush=True)
58*523fa7a6SAndroid Build Coastguard Worker        generated_tokens.append(current_token)
59*523fa7a6SAndroid Build Coastguard Worker
60*523fa7a6SAndroid Build Coastguard Worker    print("", flush=True)
61*523fa7a6SAndroid Build Coastguard Worker
62*523fa7a6SAndroid Build Coastguard Worker    return generated_tokens
63*523fa7a6SAndroid Build Coastguard Worker
64*523fa7a6SAndroid Build Coastguard Worker
65*523fa7a6SAndroid Build Coastguard Workerdef main(args):
66*523fa7a6SAndroid Build Coastguard Worker    seed = 42
67*523fa7a6SAndroid Build Coastguard Worker    torch.manual_seed(seed)
68*523fa7a6SAndroid Build Coastguard Worker    model_name = "microsoft/Phi-3-mini-4k-instruct"
69*523fa7a6SAndroid Build Coastguard Worker    model = Phi3ForCausalLM.from_pretrained(model_name)
70*523fa7a6SAndroid Build Coastguard Worker    tokenizer = AutoTokenizer.from_pretrained(model_name)
71*523fa7a6SAndroid Build Coastguard Worker
72*523fa7a6SAndroid Build Coastguard Worker    tokens = tokenizer.encode(args.prompt, return_tensors="pt")
73*523fa7a6SAndroid Build Coastguard Worker
74*523fa7a6SAndroid Build Coastguard Worker    start = time.time()
75*523fa7a6SAndroid Build Coastguard Worker    generated_tokens = (
76*523fa7a6SAndroid Build Coastguard Worker        _generate_token_with_kv_cache(args, model, tokens)
77*523fa7a6SAndroid Build Coastguard Worker        if args.use_kv_cache
78*523fa7a6SAndroid Build Coastguard Worker        else _generate_token(args, model, tokens)
79*523fa7a6SAndroid Build Coastguard Worker    )
80*523fa7a6SAndroid Build Coastguard Worker    end = time.time()
81*523fa7a6SAndroid Build Coastguard Worker
82*523fa7a6SAndroid Build Coastguard Worker    print(
83*523fa7a6SAndroid Build Coastguard Worker        "Generated response: \n {}".format(
84*523fa7a6SAndroid Build Coastguard Worker            tokenizer.decode(
85*523fa7a6SAndroid Build Coastguard Worker                generated_tokens,
86*523fa7a6SAndroid Build Coastguard Worker                skip_special_tokens=True,
87*523fa7a6SAndroid Build Coastguard Worker                clean_up_tokenization_spaces=False,
88*523fa7a6SAndroid Build Coastguard Worker            )
89*523fa7a6SAndroid Build Coastguard Worker        ),
90*523fa7a6SAndroid Build Coastguard Worker        flush=True,
91*523fa7a6SAndroid Build Coastguard Worker    )
92*523fa7a6SAndroid Build Coastguard Worker    print(f"Time spent: {end - start}", flush=True)
93*523fa7a6SAndroid Build Coastguard Worker
94*523fa7a6SAndroid Build Coastguard Worker
95*523fa7a6SAndroid Build Coastguard Workerif __name__ == "__main__":
96*523fa7a6SAndroid Build Coastguard Worker    parser = argparse.ArgumentParser()
97*523fa7a6SAndroid Build Coastguard Worker    parser.add_argument(
98*523fa7a6SAndroid Build Coastguard Worker        "-s",
99*523fa7a6SAndroid Build Coastguard Worker        "--seq_len",
100*523fa7a6SAndroid Build Coastguard Worker        type=int,
101*523fa7a6SAndroid Build Coastguard Worker        default=128,
102*523fa7a6SAndroid Build Coastguard Worker        help="Maximum number of tokens to generate",
103*523fa7a6SAndroid Build Coastguard Worker    )
104*523fa7a6SAndroid Build Coastguard Worker    parser.add_argument(
105*523fa7a6SAndroid Build Coastguard Worker        "-kv",
106*523fa7a6SAndroid Build Coastguard Worker        "--use_kv_cache",
107*523fa7a6SAndroid Build Coastguard Worker        default=False,
108*523fa7a6SAndroid Build Coastguard Worker        action="store_true",
109*523fa7a6SAndroid Build Coastguard Worker        help="Whether or not to use KV cache",
110*523fa7a6SAndroid Build Coastguard Worker    )
111*523fa7a6SAndroid Build Coastguard Worker    parser.add_argument(
112*523fa7a6SAndroid Build Coastguard Worker        "-p",
113*523fa7a6SAndroid Build Coastguard Worker        "--prompt",
114*523fa7a6SAndroid Build Coastguard Worker        type=str,
115*523fa7a6SAndroid Build Coastguard Worker        default="Tell me a story",
116*523fa7a6SAndroid Build Coastguard Worker        help="Prompt as input for the model",
117*523fa7a6SAndroid Build Coastguard Worker    )
118*523fa7a6SAndroid Build Coastguard Worker    main(parser.parse_args())
119