# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to run phi-3-mini model in eager mode.

import argparse
import time

import torch

from transformers import AutoTokenizer, Phi3ForCausalLM

from .phi_3_mini import Phi3Mini

# Token id that terminates generation (<|endoftext|> for Phi-3-mini).
end_of_text_token = 32000


def _generate_token(args, model, prompt_tokens):
    # Greedy decoding without a KV cache: the whole prompt, growing by one
    # token per step, is re-run through the model on every iteration.
    current_token = 0
    generated_tokens = []

    print("Generating tokens:", end="", flush=True)

    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
        outputs = model.forward(input_ids=prompt_tokens)
        current_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)
        prompt_tokens = torch.cat(
            [prompt_tokens, torch.tensor([[current_token]], dtype=torch.long)], dim=-1
        )

    print("", flush=True)

    return generated_tokens


def _generate_token_with_kv_cache(args, model, prompt_tokens):
    # Greedy decoding with a KV cache: the prompt is prefilled once through the
    # Phi3Mini wrapper, then only the newest token is fed back on each step.
    print("Generating tokens:", end="", flush=True)

    # Wrap the eager model with batch size 1 and a sequence budget covering the
    # prompt plus the requested number of generated tokens.
    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])
    result = model.forward(input_ids=prompt_tokens)

    current_token = torch.argmax(result, dim=-1).item()
    print(f" {current_token}", end="", flush=True)
    generated_tokens = [current_token]

    while current_token != end_of_text_token and len(generated_tokens) < args.seq_len:
        result = model.forward(
            input_ids=torch.tensor([[current_token]], dtype=torch.long),
        )
        current_token = torch.argmax(result, dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens.append(current_token)

    print("", flush=True)

    return generated_tokens


def main(args):
    seed = 42
    torch.manual_seed(seed)
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokens = tokenizer.encode(args.prompt, return_tensors="pt")

    start = time.time()
    generated_tokens = (
        _generate_token_with_kv_cache(args, model, tokens)
        if args.use_kv_cache
        else _generate_token(args, model, tokens)
    )
    end = time.time()

    print(
        "Generated response: \n {}".format(
            tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        ),
        flush=True,
    )
    print(f"Time spent: {end - start}", flush=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--seq_len",
        type=int,
        default=128,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "-kv",
        "--use_kv_cache",
        default=False,
        action="store_true",
        help="Whether or not to use KV cache",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default="Tell me a story",
        help="Prompt as input for the model",
    )
    main(parser.parse_args())
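# Example invocation (a sketch, not verified against this repo's layout: the
# module path below is an assumption and depends on where this example lives
# in the source tree; it must be run with `-m` from the repository root so the
# relative import of Phi3Mini resolves). The flags come from the argparse
# definitions above:
#
#   python -m examples.models.phi_3_mini.eager --use_kv_cache --seq_len 64 --prompt "Tell me a story"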