# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import sys

import torch
from executorch.examples.models.llava.image_util import prepare_image
from executorch.examples.models.llava.model import LlavaModel
from executorch.extension.pybindings.portable_lib import _load_for_executorch
from PIL import Image

# Custom ops has to be loaded after portable_lib.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa # usort: skip

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.DEBUG, format=FORMAT)

# Number of single-token decode steps run after the three prefill stages.
_NUM_DECODE_STEPS = 4


def _embed_tokens(module, token_ids):
    """Run the module's ``token_embedding`` method and return its first output."""
    return module.run_method("token_embedding", (token_ids,))[0]


def _run_text_model(module, start_pos, embeds):
    """Run the module's ``text_model`` method on ``embeds`` at ``start_pos``.

    ``start_pos`` is wrapped in a one-element int64 tensor, which is the form
    the exported method is called with throughout this script. Returns the
    method's first output.
    """
    position = torch.tensor([start_pos], dtype=torch.int64)
    return module.run_method("text_model", (position, embeds))[0]


def main():
    """Run prefill and a few greedy decode steps on an exported LLaVA program.

    Command line: ``<model_path> [<image_path>]``. With no arguments the usage
    string is printed and the process exits with status 1. When an image path
    is given, that image replaces the default one returned by
    ``LlavaModel.get_inputs_for_prefill()``.

    The prompt is fed in three segments (text before the image, the image
    embeddings, text after the image). After each segment the running
    ``start_pos`` is advanced by dimension 1 of what was fed in, so the next
    call starts where the previous one ended.
    """
    args = sys.argv[1:]
    if not args:
        print(
            "Usage: python test_pte.py <model_path> <image_path?>. If no image, will use default image."
        )
        sys.exit(1)

    llava_module = _load_for_executorch(args[0])

    llava_model = LlavaModel()

    prompt_before_image, resized, prompt_after_image = (
        llava_model.get_inputs_for_prefill()
    )
    if len(args) == 2:
        image_path = args[1]
        # Close the file handle once the image has been converted.
        with Image.open(image_path) as image:
            resized = prepare_image(image, target_h=336, target_w=336)

    start_pos = 0
    # pte prefill prompt before img
    pte_embeds_before_img = _embed_tokens(llava_module, prompt_before_image)
    pte_prefill_before_img = _run_text_model(
        llava_module, start_pos, pte_embeds_before_img
    )
    print(pte_prefill_before_img)

    # Dimension 1 is used as the segment length here and below
    # (presumably the sequence axis -- TODO confirm against the exported model).
    start_pos += prompt_before_image.shape[1]

    # pte prefill image
    logging.warning("Image encoder started")
    pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
    logging.warning("Image encoder finished")
    logging.warning("Image token prefill started")
    pte_prefill_img = _run_text_model(llava_module, start_pos, pte_embeds_img)
    logging.warning("Image token prefill finished")
    print(pte_prefill_img)

    start_pos += pte_embeds_img.shape[1]

    # pte prefill prompt after img
    logging.warning("Text token prefill started")
    pte_embeds_after_img = _embed_tokens(llava_module, prompt_after_image)
    pte_prefill_after_img = _run_text_model(
        llava_module, start_pos, pte_embeds_after_img
    )
    logging.warning("Text token prefill finished")
    print(pte_prefill_after_img)

    # Advance past the post-image prompt as well; otherwise the decode steps
    # below would be issued at positions this segment already occupies.
    # The length comes from the embeddings that were fed in, not the logits.
    start_pos += pte_embeds_after_img.shape[1]

    # Greedy decode: seed with the argmax of the last position's logits, then
    # feed each new token back in one at a time.
    new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
    for i in range(_NUM_DECODE_STEPS):
        print(i, llava_model.tokenizer.decode(new_tokens[i]))
        token_embeds = _embed_tokens(
            llava_module, torch.tensor([[new_tokens[i]]], dtype=torch.int64)
        )
        logits = _run_text_model(llava_module, start_pos + i, token_embeds)
        new_tokens.append(torch.argmax(logits[..., -1, :]).item())

    outputs = llava_model.tokenizer.batch_decode(
        torch.tensor([new_tokens]), skip_special_tokens=True
    )[0].strip()
    print(outputs)


if __name__ == "__main__":
    main()