# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import sys

import torch
from executorch.examples.models.llava.image_util import prepare_image
from executorch.examples.models.llava.model import LlavaModel
from executorch.extension.pybindings.portable_lib import _load_for_executorch
from PIL import Image

# Custom ops have to be loaded after portable_lib.
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
from executorch.kernels import quantized  # noqa # usort: skip

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.DEBUG, format=FORMAT)


def main():
    args = sys.argv[1:]
    if len(args) == 0:
        print(
            "Usage: python test_pte.py <model_path> [<image_path>]. If no image path is given, a default image is used."
        )
        sys.exit(1)

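    # The exported .pte bundles three callable methods: "token_embedding"
    # (token ids -> embeddings), "image_encoder" (preprocessed image -> image
    # embeddings), and "text_model" (cache position + embeddings -> logits).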
    llava_module = _load_for_executorch(args[0])

    llava_model = LlavaModel()

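    # LlavaModel supplies the tokenized prompt split around the image, plus a
    # default preprocessed image.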
    prompt_before_image, resized, prompt_after_image = (
        llava_model.get_inputs_for_prefill()
    )
    if len(args) == 2:
        image_path = args[1]
        image = Image.open(image_path)
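        # 336x336 matches the input resolution of LLaVA's CLIP ViT-L/14-336px
        # vision tower.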
        resized = prepare_image(image, target_h=336, target_w=336)

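    # Prefill runs in three sequential chunks: text before the image, image
    # embeddings, then text after the image; start_pos tracks the running
    # KV-cache position across chunks. run_method returns a list of outputs,
    # so [0] selects each method's single output tensor.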
    start_pos = 0
    # Prefill the prompt tokens that precede the image.
    pte_embeds_before_img = llava_module.run_method(
        "token_embedding", (prompt_before_image,)
    )[0]
    pte_prefill_before_img = llava_module.run_method(
        "text_model",
        (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
    )[0]
    print(pte_prefill_before_img)

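    # Advance the cache position by the length of the chunk just prefilled.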
    start_pos += prompt_before_image.shape[1]

    # Encode the image and prefill the resulting image-embedding tokens.
    logging.warning("Image encoder started")
    pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
    logging.warning("Image encoder finished")
    logging.warning("Image token prefill started")
    pte_prefill_img = llava_module.run_method(
        "text_model",
        (
            torch.tensor([start_pos], dtype=torch.int64),
            pte_embeds_img,
        ),
    )[0]
    logging.warning("Image token prefill finished")
    print(pte_prefill_img)

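    # The image consumes one cache position per embedding; for LLaVA's CLIP
    # ViT-L/14 encoder at 336x336 input, that is (336/14)^2 = 576 tokens.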
    start_pos += pte_embeds_img.shape[1]

    # Prefill the prompt tokens that follow the image.
    logging.warning("Text token prefill started")
    pte_embeds_after_img = llava_module.run_method(
        "token_embedding", (prompt_after_image,)
    )[0]
    pte_prefill_after_img = llava_module.run_method(
        "text_model",
        (torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
    )[0]
    logging.warning("Text token prefill finished")
    print(pte_prefill_after_img)

    # The part under test: token-by-token decode with the exported
    # llama_transformer text model.
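    # Greedy decoding: seed with the argmax over the last-position logits of
    # the final prefill chunk, then embed each new token and run it through
    # text_model at the next cache position.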
    new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
    for i in range(4):
        print(i, llava_model.tokenizer.decode(new_tokens[i]))
        token_embeds = llava_module.run_method(
            "token_embedding", (torch.tensor([[new_tokens[i]]], dtype=torch.int64),)
        )[0]
        logits = llava_module.run_method(
            "text_model",
            (torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
        )[0]
        new_tokens.append(torch.argmax(logits[..., -1, :]).item())

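    # Decode the generated token ids back into text.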
    outputs = llava_model.tokenizer.batch_decode(
        torch.tensor([new_tokens]), skip_special_tokens=True
    )[0].strip()
    print(outputs)


if __name__ == "__main__":
    main()