# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Script to run phi-3-mini model in eager mode.

import argparse
import time

import torch

from transformers import AutoTokenizer, Phi3ForCausalLM

from .phi_3_mini import Phi3Mini

# Token id at which generation stops (presumably Phi-3-mini's <|endoftext|>
# special token — TODO confirm against the tokenizer's eos_token_id).
end_of_text_token = 32000


def _generate_token(args, model, prompt_tokens):
    """Greedily generate tokens WITHOUT a KV cache.

    The full (prompt + generated-so-far) sequence is re-fed to the model on
    every step, so per-step cost grows with sequence length.

    Args:
        args: Parsed CLI namespace; only ``args.seq_len`` (max new tokens) is read.
        model: Causal LM whose ``forward`` returns an object with ``.logits``
            of shape (batch, seq, vocab).
        prompt_tokens: Long tensor of shape (1, prompt_len) holding the encoded prompt.

    Returns:
        list[int]: Generated token ids, excluding the prompt. Includes the
        end-of-text token if one was produced before the length limit.
    """
    current_token = 0
    generated_tokens = []

    print("Generating tokens:", end="", flush=True)

    # Inference only: without no_grad() an autograd graph would be built and
    # retained across every generated token, wasting memory and compute.
    with torch.no_grad():
        while (
            current_token != end_of_text_token
            and len(generated_tokens) < args.seq_len
        ):
            outputs = model.forward(input_ids=prompt_tokens)
            # Greedy decoding: argmax over the last position's logits.
            current_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).item()
            print(f" {current_token}", end="", flush=True)
            generated_tokens.append(current_token)
            # Extend the context so the next step sees the new token.
            prompt_tokens = torch.cat(
                [prompt_tokens, torch.tensor([[current_token]], dtype=torch.long)],
                dim=-1,
            )

    print("", flush=True)

    return generated_tokens


def _generate_token_with_kv_cache(args, model, prompt_tokens):
    """Greedily generate tokens using the Phi3Mini KV-cache wrapper.

    The prompt is processed once to prime the cache; each later step feeds
    only the single newest token, so per-step cost is constant in sequence
    length.

    Args and return value: same contract as ``_generate_token``.
    """
    print("Generating tokens:", end="", flush=True)

    # Cache must hold the prompt plus every token we may generate.
    model = Phi3Mini(model, 1, args.seq_len + prompt_tokens.shape[-1])

    # Inference only — see note in _generate_token.
    with torch.no_grad():
        # Prime the cache with the whole prompt. NOTE(review): Phi3Mini.forward
        # appears to return last-position logits directly (it is argmax'd
        # without a .logits attribute) — confirm against phi_3_mini.py.
        result = model.forward(input_ids=prompt_tokens)

        current_token = torch.argmax(result, dim=-1).item()
        print(f" {current_token}", end="", flush=True)
        generated_tokens = [current_token]

        while (
            current_token != end_of_text_token
            and len(generated_tokens) < args.seq_len
        ):
            # Only the newest token is fed; prior context lives in the cache.
            result = model.forward(
                input_ids=torch.tensor([[current_token]], dtype=torch.long),
            )
            current_token = torch.argmax(result, dim=-1).item()
            print(f" {current_token}", end="", flush=True)
            generated_tokens.append(current_token)

    print("", flush=True)

    return generated_tokens


def main(args):
    """Load phi-3-mini, generate a response for ``args.prompt``, and print it.

    Dispatches to the KV-cache path when ``args.use_kv_cache`` is set,
    otherwise re-feeds the full sequence each step. Also reports wall-clock
    generation time.
    """
    seed = 42
    torch.manual_seed(seed)
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    model = Phi3ForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    tokens = tokenizer.encode(args.prompt, return_tensors="pt")

    start = time.time()
    generated_tokens = (
        _generate_token_with_kv_cache(args, model, tokens)
        if args.use_kv_cache
        else _generate_token(args, model, tokens)
    )
    end = time.time()

    print(
        "Generated response: \n {}".format(
            tokenizer.decode(
                generated_tokens,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
        ),
        flush=True,
    )
    print(f"Time spent: {end - start}", flush=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-s",
        "--seq_len",
        type=int,
        default=128,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "-kv",
        "--use_kv_cache",
        default=False,
        action="store_true",
        help="Whether or not to use KV cache",
    )
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        default="Tell me a story",
        help="Prompt as input for the model",
    )
    main(parser.parse_args())